RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cuda:0)

Greetings,

I am writing a diffusion model with the PyTorch Lightning framework and training it on 2 GPUs. I run into the following error:

All distributed processes registered. Starting with 2 processes
----------------------------------------------------------------------------------------------------

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1]
/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/transformers/optimization.py:407: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  warnings.warn(
/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/transformers/optimization.py:407: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  warnings.warn(

  | Name  | Type           | Params
-----------------------------------------
0 | model | DiffusionModel | 72.4 M
-----------------------------------------
68.1 M    Trainable params
4.2 M     Non-trainable params
72.4 M    Total params
289.470   Total estimated model params size (MB)
Sanity Checking: |                                                                                                                  | 0/? [00:00<?, ?it/s]/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
1
Traceback (most recent call last):
  File "/home/xz479/DiffuBot/train.py", line 148, in <module>
    train(args)
  File "/home/xz479/DiffuBot/train.py", line 134, in train
    trainer.fit(task, train_loader, val_loader)
  File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 545, in fit
    call._call_and_handle_interrupt(
  File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py", line 43, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 102, in launch
    return function(*args, **kwargs)
  File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 581, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 990, in _run
    results = self._run_stage()
  File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1034, in _run_stage
    self._run_sanity_check()
  File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1063, in _run_sanity_check
    val_loop.run()
  File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/loops/utilities.py", line 181, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 134, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
  File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 391, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_args)
  File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py", line 309, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py", line 402, in validation_step
    return self._forward_redirection(self.model, self.lightning_module, "validation_step", *args, **kwargs)
  File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py", line 628, in __call__
    wrapper_output = wrapper_module(*args, **kwargs)
  File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
  File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py", line 621, in wrapped_forward
    out = method(*_args, **_kwargs)
  File "/home/xz479/DiffuBot/train.py", line 91, in validation_step
    loss = self.model.loss_fn(batch['dst_matrix'])
  File "/home/xz479/DiffuBot/train.py", line 42, in loss_fn
    output, epsilon, alpha_bar = self.forward(x, idx=idx, get_target=True)
  File "/home/xz479/DiffuBot/train.py", line 54, in forward
    used_alpha_bars = self.alpha_bars[idx][:, None, None]
RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cuda:0)
Sanity Checking DataLoader 0:   0%|                                                                                                 | 0/2 [00:00<?, ?it/s]0
0
0
torch.Size([])
Sanity Checking DataLoader 0:  50%|████████████████████████████████████████████▌                                            | 1/2 [00:00<00:00, 51.59it/s]0
0
0
torch.Size([])
Sanity Checking DataLoader 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:03<00:00,  0.55it/s]/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:429: It is recommended to use `self.log('validation_epoch_average', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
[rank: 1] Child process with PID 1109548 terminated with code 1. Forcefully terminating all other processes to avoid zombies 🧟
Killed

My code is below

class DiffusionModel(nn.Module):
    
    def __init__(self, device, beta_1, beta_T, T, decoder =None):
        '''
        The epsilon predictor of diffusion process.

        beta_1    : beta_1 of diffusion process
        beta_T    : beta_T of diffusion process
        T         : Diffusion Steps
        input_dim : a dimension of data

        '''

        super().__init__()
        self.device = device
        self.alpha_bars = torch.cumprod(1 - torch.linspace(start = beta_1, end=beta_T, steps=T), dim = 0).to(device = self.device)
        self.backbone = Transformer_Denoiser(T, batch_first=True).to(device = self.device)
        
        self.to(device = self.device)
    
    def loss_fn(self, x, idx=None):
        '''
        This function is used only during the training phase.

        x          : real data if idx is None, else already-perturbed data
        idx        : if None (training phase), a random timestep index is sampled per example. Otherwise (inference phase), specify the timestep explicitly.

        '''
        output, epsilon, alpha_bar = self.forward(x, idx=idx, get_target=True)
        loss = (output - epsilon).square().mean()
        print(loss.shape)
        return loss
        
    def forward(self, x, idx=None, get_target=False):
        # print(x.shape)
        if idx == None:
            ### Generate random timestep index for noise sampling
            #idx_device = self.alpha_bars.get_device()
            idx = torch.randint(0, len(self.alpha_bars), (x.size(0), )).to(self.device)
            print(idx.get_device())
            used_alpha_bars = self.alpha_bars[idx][:, None, None]
            print(self.alpha_bars.get_device())
            print(used_alpha_bars.get_device())
            epsilon = torch.randn_like(x).to(self.device)
            #TODO lock noise with respect to turns
            x_tilde = torch.sqrt(used_alpha_bars) * x + torch.sqrt(1 - used_alpha_bars) * epsilon
            
        else:
            idx = torch.Tensor([idx for _ in range(x.size(0))]).long()
            x_tilde = x
        
        # print(x_tilde.shape)
        #return idx
        output = self.backbone(x_tilde, idx)
        
        return (output, epsilon, used_alpha_bars) if get_target else output
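A possible fix, sketched under assumptions rather than confirmed against this exact setup: self.alpha_bars is assigned as a plain tensor attribute, so module moves such as `.to()` (which Lightning uses to place each rank's copy of the model on its own GPU) do not touch it. If the device passed to the constructor was "cuda" or "cuda:0", alpha_bars would stay on cuda:0 in every process while rank 1's indices end up on cuda:1. The sketch registers alpha_bars as a buffer and derives the device from it instead of storing self.device. `TinyDenoiser` is a hypothetical stand-in for `Transformer_Denoiser` (whose code is not shown), and the `dim` parameter exists only for that stand-in.

```python
import torch
import torch.nn as nn


class TinyDenoiser(nn.Module):
    """Hypothetical stand-in for Transformer_Denoiser, only so the sketch runs."""

    def __init__(self, T, dim):
        super().__init__()
        self.t_embed = nn.Embedding(T, dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, idx):
        # x: (batch, seq, dim); idx: (batch,) timestep indices
        return self.proj(x + self.t_embed(idx)[:, None, :])


class DiffusionModel(nn.Module):
    def __init__(self, beta_1, beta_T, T, dim=8):
        super().__init__()
        alpha_bars = torch.cumprod(1 - torch.linspace(beta_1, beta_T, steps=T), dim=0)
        # A registered buffer moves with .to()/.cuda(), so it follows whatever
        # device this module is placed on (cuda:0 on rank 0, cuda:1 on rank 1, ...).
        self.register_buffer("alpha_bars", alpha_bars, persistent=False)
        self.backbone = TinyDenoiser(T, dim)

    def loss_fn(self, x):
        # Training only: timesteps are sampled at random inside forward().
        output, epsilon, _ = self.forward(x, get_target=True)
        return (output - epsilon).square().mean()

    def forward(self, x, idx=None, get_target=False):
        device = self.alpha_bars.device  # derive the device instead of storing it
        if idx is None:
            idx = torch.randint(0, len(self.alpha_bars), (x.size(0),), device=device)
            used_alpha_bars = self.alpha_bars[idx][:, None, None]
            epsilon = torch.randn_like(x)  # created on x's device
            x_tilde = torch.sqrt(used_alpha_bars) * x + torch.sqrt(1 - used_alpha_bars) * epsilon
        else:
            idx = torch.full((x.size(0),), int(idx), dtype=torch.long, device=device)
            used_alpha_bars = self.alpha_bars[idx][:, None, None]
            epsilon = None  # x is already perturbed, so there is no target noise
            x_tilde = x
        output = self.backbone(x_tilde, idx)
        return (output, epsilon, used_alpha_bars) if get_target else output
```

With this version the model is constructed without a device argument and the Trainer places it. `persistent=False` keeps alpha_bars out of the state_dict, so checkpoints saved before the change should still load; drop it if you want the schedule saved with the weights. The else branch also defines used_alpha_bars and epsilon, which the original left undefined when get_target=True.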

I am fairly sure all the tensors are on the same device: in the code I print the devices of the index tensor, the alpha_bars tensor, and the indexed result, and the printed output (the 0 lines in the log) shows them all on device 0, which is also the device named in the error message.
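The prints may not tell the whole story. The 0 lines come from the process that completed the sanity check, while the traceback comes from rank 1 (the last log lines name the rank-1 child process); the lone 1 printed just before the traceback appears to be rank 1's idx.get_device(), which would put its indices on cuda:1 while alpha_bars sits on cuda:0. Separately, `nn.Module.to()` only converts parameters and registered buffers; a tensor stored as a plain attribute, like alpha_bars above, stays where it was created. The CPU-only sketch below shows this with a dtype conversion instead of a device move, since both go through the same `Module._apply` path; `Demo` and its attribute names are hypothetical.

```python
import torch
import torch.nn as nn


class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.plain = torch.ones(3)                       # plain attribute: skipped by .to()
        self.register_buffer("buffered", torch.ones(3))  # buffer: converted by .to()


# A dtype conversion uses the same mechanism as a device move,
# so this runs without any GPU.
m = Demo().to(torch.float64)
print(m.plain.dtype)     # torch.float32
print(m.buffered.dtype)  # torch.float64
```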

I'd really appreciate it if someone could help.

Thank you very much.


I'm facing the same issue when using PyTorch Lightning 2.1.1 with the DDP strategy and the wandb logger. I could not find a workaround or a solution.
Please help, thanks.

Are you seeing these issues without PyTorch Lightning, in a pure PyTorch model?
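A quick pure-PyTorch check that needs no distributed setup is to list tensor attributes that a module's `.to()`/`.cuda()` would skip. A minimal sketch; `find_unregistered_tensors` is a hypothetical helper name, and it relies on `nn.Module` keeping parameters and buffers in its internal `_parameters`/`_buffers` dicts rather than directly in the instance `__dict__`.

```python
import torch
import torch.nn as nn


def find_unregistered_tensors(module: nn.Module):
    """Return dotted names of tensor attributes that .to()/.cuda() will not move."""
    found = []
    for mod_name, mod in module.named_modules():
        for attr, value in vars(mod).items():
            # Parameters and buffers live in mod._parameters / mod._buffers,
            # so a bare tensor found here is an unregistered attribute.
            if isinstance(value, torch.Tensor):
                found.append(f"{mod_name}.{attr}" if mod_name else attr)
    return found
```

Run on the DiffusionModel from the original post, this should report alpha_bars (plus anything similar inside Transformer_Denoiser, whose code isn't shown), because that attribute is assigned directly instead of through register_buffer.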

I can't tell, since I only use distributed training through PyTorch Lightning.
If I train on a single GPU, there are no problems.
But when I set devices=4, I receive the error:

"[rank: 2] child process with pid 4049 terminated with code -6. Forcefully terminating all other processes to avoid zombies :zombie:"

Thanks