RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cuda:0)

Greetings,

I am writing a diffusion model using the PyTorch Lightning framework and training on 2 GPUs, and I ran into the error below.

All distributed processes registered. Starting with 2 processes
----------------------------------------------------------------------------------------------------

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1]
/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/transformers/optimization.py:407: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  warnings.warn(
/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/transformers/optimization.py:407: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  warnings.warn(

  | Name  | Type           | Params
-----------------------------------------
0 | model | DiffusionModel | 72.4 M
-----------------------------------------
68.1 M    Trainable params
4.2 M     Non-trainable params
72.4 M    Total params
289.470   Total estimated model params size (MB)
Sanity Checking: |                                                                                                                  | 0/? [00:00<?, ?it/s]/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
1
Traceback (most recent call last):
  File "/home/xz479/DiffuBot/train.py", line 148, in <module>
    train(args)
  File "/home/xz479/DiffuBot/train.py", line 134, in train
    trainer.fit(task, train_loader, val_loader)
  File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 545, in fit
    call._call_and_handle_interrupt(
  File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py", line 43, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 102, in launch
    return function(*args, **kwargs)
  File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 581, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 990, in _run
    results = self._run_stage()
  File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1034, in _run_stage
    self._run_sanity_check()
  File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1063, in _run_sanity_check
    val_loop.run()
  File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/loops/utilities.py", line 181, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 134, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
  File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 391, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_args)
  File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py", line 309, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py", line 402, in validation_step
    return self._forward_redirection(self.model, self.lightning_module, "validation_step", *args, **kwargs)
  File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py", line 628, in __call__
    wrapper_output = wrapper_module(*args, **kwargs)
  File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
  File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py", line 621, in wrapped_forward
    out = method(*_args, **_kwargs)
  File "/home/xz479/DiffuBot/train.py", line 91, in validation_step
    loss = self.model.loss_fn(batch['dst_matrix'])
  File "/home/xz479/DiffuBot/train.py", line 42, in loss_fn
    output, epsilon, alpha_bar = self.forward(x, idx=idx, get_target=True)
  File "/home/xz479/DiffuBot/train.py", line 54, in forward
    used_alpha_bars = self.alpha_bars[idx][:, None, None]
RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cuda:0)
Sanity Checking DataLoader 0:   0%|                                                                                                 | 0/2 [00:00<?, ?it/s]0
0
0
torch.Size([])
Sanity Checking DataLoader 0:  50%|ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–Œ                                            | 1/2 [00:00<00:00, 51.59it/s]0
0
0
torch.Size([])
Sanity Checking DataLoader 0: 100%|ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 2/2 [00:03<00:00,  0.55it/s]/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:429: It is recommended to use `self.log('validation_epoch_average', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
[rank: 1] Child process with PID 1109548 terminated with code 1. Forcefully terminating all other processes to avoid zombies 🧟                           
Killed

My code is below:

class DiffusionModel(nn.Module):
    
    def __init__(self, device, beta_1, beta_T, T, decoder =None):
        '''
        The epsilon predictor of the diffusion process.

        beta_1    : beta_1 of diffusion process
        beta_T    : beta_T of diffusion process
        T         : Diffusion Steps
        input_dim : a dimension of data

        '''

        super().__init__()
        self.device = device
        self.alpha_bars = torch.cumprod(1 - torch.linspace(start = beta_1, end=beta_T, steps=T), dim = 0).to(device = self.device)
        self.backbone = Transformer_Denoiser(T, batch_first=True).to(device = self.device)
        
        self.to(device = self.device)
    
    def loss_fn(self, x, idx=None):
        '''
        This function is used only during the training phase.

        x          : real data if idx is None, else already-perturbed data
        idx        : if None (training phase), a random timestep index is sampled. Otherwise (inference phase), it is recommended to specify it.

        '''
        output, epsilon, alpha_bar = self.forward(x, idx=idx, get_target=True)
        loss = (output - epsilon).square().mean()
        print(loss.shape)
        return loss
        
    def forward(self, x, idx=None, get_target=False):
        # print(x.shape)
        if idx == None:
            ### Generate random timestep index for noise sampling
            #idx_device = self.alpha_bars.get_device()
            idx = torch.randint(0, len(self.alpha_bars), (x.size(0), )).to(self.device)
            print(idx.get_device())
            used_alpha_bars = self.alpha_bars[idx][:, None, None]
            print(self.alpha_bars.get_device())
            print(used_alpha_bars.get_device())
            epsilon = torch.randn_like(x).to(self.device)
            #TODO lock noise with respect to turns
            x_tilde = torch.sqrt(used_alpha_bars) * x + torch.sqrt(1 - used_alpha_bars) * epsilon
            
        else:
            idx = torch.Tensor([idx for _ in range(x.size(0))]).long()
            x_tilde = x
        
        # print(x_tilde.shape)
        #return idx
        output = self.backbone(x_tilde, idx)
        
        return (output, epsilon, used_alpha_bars) if get_target else output

I am fairly sure all the tensors are on the same device: in the code I print out the devices of the alpha_bars tensor and the index tensor, and as shown in the printed output, they are all on device 0.
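For reference, here is the indexing pattern in isolation. With both tensors explicitly on cuda:0 it runs without error as a standalone snippet outside Lightning (the beta values, number of steps, and batch size are just placeholders, not my real hyperparameters):

import torch

device = torch.device('cuda:0')
# Same construction as in __init__, with placeholder hyperparameters.
alpha_bars = torch.cumprod(1 - torch.linspace(1e-4, 0.02, 1000), dim=0).to(device)
# Same index sampling as in forward(), with a placeholder batch size of 8.
idx = torch.randint(0, len(alpha_bars), (8,)).to(device)
print(alpha_bars.device, idx.device)              # both report cuda:0
used_alpha_bars = alpha_bars[idx][:, None, None]  # indexes without raising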

I would really appreciate it if someone could help.

Thank you very much.

I'm facing the same issue when using PyTorch Lightning 2.1.1 with the DDP strategy and the wandb logger; my setup is sketched below. I could not find a workaround or a solution.
Please help. Thanks.
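For context, my setup looks roughly like this (a simplified sketch; the project name and the model/loader variables are placeholders, not my actual code):

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

trainer = pl.Trainer(
    accelerator='gpu',
    devices=4,
    strategy='ddp',
    logger=WandbLogger(project='diffusion'),  # placeholder project name
)
trainer.fit(model, train_loader, val_loader)  # model and loaders defined elsewhere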

Are you seeing these issues without PyTorch Lightning, i.e. in a pure PyTorch model?
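E.g. something along these lines (the hyperparameter values are guesses based on your snippet, and moving the module after construction mimics what a multi-GPU launcher does on a non-default rank) would show whether a device mismatch also appears outside the Lightning/DDP wrapper:

import torch

# Hyperparameter values are guesses from the posted snippet -- adjust to your real ones.
model = DiffusionModel(device=torch.device('cuda:0'), beta_1=1e-4, beta_T=0.02, T=1000)

# Mimic what a multi-process launcher does on rank 1: move the module after construction.
model.to('cuda:1')

# Dummy batch with the (batch, seq, dim) layout that the [:, None, None] broadcasting implies.
x = torch.randn(8, 16, 32, device='cuda:1')
loss = model.loss_fn(x)
print(loss)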

I can't tell, since the only way I do distributed training is through PyTorch Lightning.
If I train on a single GPU, there are no problems.
But when I set devices=4, I receive the error:

"[rank: 2] child process with pid 4049 terminated with code -6. Forcefully terminating all other processes to avoid zombies 🧟"

Thanks