Data parallelization meets RuntimeError: copy_if failed to synchronize: cudaErrorAssert: device-side assert triggered

Hi,

I am trying to parallelize my loss module:

import torch
import torch.nn as nn
import torch.nn.functional as F


class Supervised_EmbeddingLoss(nn.Module):
    def __init__(self, n_cluster, cos_margin, push_balance_weight, pull_momentum, push_momentum):
        super(Supervised_EmbeddingLoss, self).__init__()
        ...

    def forward(self, pred, lbl, centroids, push_minibatch, pad, ignore=255):
        N, D, H, W = pred.size()
        # crop the padded border from predictions and labels
        pred = pred[:, :, pad:H - pad, pad:W - pad]
        lbl = lbl[:, pad:H - pad, pad:W - pad]

        # flatten to (num_pixels, D) / (num_pixels,) and drop ignored pixels
        pred = pred.transpose(1, 2).transpose(2, 3).contiguous().view(-1, D)
        lbl = lbl.contiguous().view(-1)
        pred = pred[lbl != ignore]
        lbl = lbl[lbl != ignore]

        # pull each embedding towards the centroid of its own class
        pull_loss = F.cosine_embedding_loss(pred, centroids[lbl], margin=self.cos_margin,
                                            target=torch.ones_like(lbl, dtype=torch.float), reduction='mean')

        # push embeddings away from the centroids of the classes they do not belong to,
        # optionally using a random subset of push_minibatch pixels per class
        push_loss = 0.
        for i in range(self.n_cluster):
            _pred = pred[lbl != i]
            if push_minibatch != -1:
                r = torch.rand(size=[_pred.size()[0]])
                rinds = torch.argsort(r)[:push_minibatch]
                _pred = _pred[rinds]
            push_loss += F.cosine_embedding_loss(_pred, centroids[i].view(1, D), margin=self.cos_margin,
                                                 target=-1 * torch.ones((_pred.size()[0]), dtype=torch.float).cuda(),
                                                 reduction='mean')

        embed_loss = pull_loss + self.push_balance_weight * push_loss
        return {'embed_loss': embed_loss, 'pull_loss': pull_loss, 'push_loss': push_loss}
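
For reference, the wrapping looks roughly like this (a minimal sketch; the model and trainer code are omitted, and the variable names just mirror the call that appears in the traceback below):

import torch

criterion = Supervised_EmbeddingLoss(n_cluster, cos_margin, push_balance_weight,
                                     pull_momentum, push_momentum)
# replicate the loss module itself across the available GPUs
supervised_loss = torch.nn.DataParallel(criterion).cuda()

# per iteration: tensor inputs are split along dim 0 across the replicas
losses = supervised_loss(src_pred_embed, source_lbl, centroids,
                         push_minibatch, loss_pad)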

When I train my model along with this loss function on one GPU, it works fine. However, when I wrap the loss module in torch.nn.DataParallel() and try to parallelize it across two GPUs, I get an error like:

...
...
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:60: lambda [](int)->auto::operator()(int)->auto: block: [460,0,0], thread: [12,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
(the same assertion is repeated for threads [13,0,0] through [31,0,0] of block [460,0,0])

Traceback (most recent call last):                                              
  File "train_3/train_0_0_thresh_0_6_mo_1_batch_2_test.py", line 199, in <module>
    trainer.train()
  File "/projectnb/cs585/kai/mt_clust/trainer/mt_trainer.py", line 253, in train
    self._train_epoch()
  File "/projectnb/cs585/kai/mt_clust/trainer/mt_trainer.py", line 102, in _train_epoch
    src_embed_loss = self.supervised_loss(src_pred_embed, source_lbl, centroids, self.cfg.push_minibatch, self.cfg.loss_pad)
  File "/share/pkg.7/pytorch/1.3/install/3.6/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/share/pkg.7/pytorch/1.3/install/3.6/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 152, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/share/pkg.7/pytorch/1.3/install/3.6/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 162, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/share/pkg.7/pytorch/1.3/install/3.6/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 85, in parallel_apply
    output.reraise()
  File "/share/pkg.7/pytorch/1.3/install/3.6/lib/python3.6/site-packages/torch/_utils.py", line 385, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/share/pkg.7/pytorch/1.3/install/3.6/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
    output = module(*input, **kwargs)
  File "/share/pkg.7/pytorch/1.3/install/3.6/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/projectnb/cs585/kai/mt_clust/model/losses/loss_embed.py", line 34, in forward
    _pred = pred[lbl != i]
RuntimeError: copy_if failed to synchronize: cudaErrorAssert: device-side assert triggered

It seems like the problem is in the indexing line: _pred = pred[lbl != i]
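
Since CUDA kernels are launched asynchronously, the line reported in the traceback is not necessarily the op that triggered the assert; rerunning with blocking launches should point at the exact failing operation (using the script name from the traceback above):

CUDA_LAUNCH_BLOCKING=1 python train_3/train_0_0_thresh_0_6_mo_1_batch_2_test.py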

Can someone give me a clue about this weird problem?

Could you print the shape of pred as well as the result of lbl != i for the failing operation?
This should give you an idea of why the index is out of bounds.
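
Something like this right before the push loop should show what each replica sees (a debugging sketch; the extra prints of centroids.shape and the label range are just to rule out an out-of-range index):

# temporary debugging inside forward(), after the ignore filtering
print(pred.device, pred.shape, lbl.shape, centroids.shape,
      lbl.min().item(), lbl.max().item())
for i in range(self.n_cluster):
    mask = lbl != i
    print(pred.device, 'cluster', i, 'pixels pushed away:', mask.sum().item())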

I printed their sizes out. Since the computation is parallelized over two GPUs, the pred and lbl sizes on each GPU are:

torch.Size([210226, 128])
torch.Size([210226])

and

torch.Size([235481, 128])
torch.Size([235481])
The most confusing thing is that when I disable the parallelization and run it on a single GPU, it works perfectly.

Thanks for the shapes. Could you check if pred might be empty after the

pred = pred[lbl != ignore]

operation?
Based on your description I guess you might be running into a scenario where one device gets only ignore indices, and indexing pred again then yields this error.
You could check whether pred.size(0) is 0 after the first indexing.
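
For example (a small debugging sketch, placed right after the ignore filtering in forward()):

# temporary check right after `pred = pred[lbl != ignore]`
if pred.size(0) == 0:
    print('Empty pred after ignore filtering on', pred.device)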

I ran the code and it returned:

236914


235627

for the two tensors, so they are not empty.

Thanks

@Kaiwkh, were you able to fix this issue? I am facing a similar one: the loss works perfectly on a single GPU, but with multiple GPUs (single node) I get the same error as mentioned above.