Greetings,
I am writing a diffusion model and training it with the pytorch-lightning framework on 2 GPUs.
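The trainer is launched roughly like this (a simplified sketch of my setup, not the literal train.py; the exact Trainer flags may differ, but task, train_loader and val_loader are the names that appear in the traceback below):

import pytorch_lightning as pl

trainer = pl.Trainer(
    accelerator="gpu",
    devices=2,        # both GPUs
    strategy="ddp",   # one process per GPU, matching the launcher in the traceback
)
trainer.fit(task, train_loader, val_loader)

During the validation sanity check I run into the following error.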
All distributed processes registered. Starting with 2 processes
----------------------------------------------------------------------------------------------------
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1]
/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/transformers/optimization.py:407: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
warnings.warn(
/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/transformers/optimization.py:407: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
warnings.warn(
| Name | Type | Params
-----------------------------------------
0 | model | DiffusionModel | 72.4 M
-----------------------------------------
68.1 M Trainable params
4.2 M Non-trainable params
72.4 M Total params
289.470 Total estimated model params size (MB)
Sanity Checking: | | 0/? [00:00<?, ?it/s]/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
1
Traceback (most recent call last):
File "/home/xz479/DiffuBot/train.py", line 148, in <module>
train(args)
File "/home/xz479/DiffuBot/train.py", line 134, in train
trainer.fit(task, train_loader, val_loader)
File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 545, in fit
call._call_and_handle_interrupt(
File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py", line 43, in _call_and_handle_interrupt
return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 102, in launch
return function(*args, **kwargs)
File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 581, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 990, in _run
results = self._run_stage()
File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1034, in _run_stage
self._run_sanity_check()
File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1063, in _run_sanity_check
val_loop.run()
File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/loops/utilities.py", line 181, in _decorator
return loop_run(self, *args, **kwargs)
File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 134, in run
self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 391, in _evaluation_step
output = call._call_strategy_hook(trainer, hook_name, *step_args)
File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py", line 309, in _call_strategy_hook
output = fn(*args, **kwargs)
File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py", line 402, in validation_step
return self._forward_redirection(self.model, self.lightning_module, "validation_step", *args, **kwargs)
File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py", line 628, in __call__
wrapper_output = wrapper_module(*args, **kwargs)
File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
output = self._run_ddp_forward(*inputs, **kwargs)
File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
return module_to_run(*inputs[0], **kwargs[0]) # type: ignore[index]
File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py", line 621, in wrapped_forward
out = method(*_args, **_kwargs)
File "/home/xz479/DiffuBot/train.py", line 91, in validation_step
loss = self.model.loss_fn(batch['dst_matrix'])
File "/home/xz479/DiffuBot/train.py", line 42, in loss_fn
output, epsilon, alpha_bar = self.forward(x, idx=idx, get_target=True)
File "/home/xz479/DiffuBot/train.py", line 54, in forward
used_alpha_bars = self.alpha_bars[idx][:, None, None]
RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cuda:0)
Sanity Checking DataLoader 0: 0%| | 0/2 [00:00<?, ?it/s]0
0
0
torch.Size([])
Sanity Checking DataLoader 0:  50%|█████████████████████████████████████████████                                             | 1/2 [00:00<00:00, 51.59it/s]0
0
0
torch.Size([])
Sanity Checking DataLoader 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:03<00:00, 0.55it/s]/home/xz479/miniconda3/envs/delibot/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:429: It is recommended to use `self.log('validation_epoch_average', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
[rank: 1] Child process with PID 1109548 terminated with code 1. Forcefully terminating all other processes to avoid zombies 🧟
Killed
My code (the DiffusionModel in train.py) is below:
class DiffusionModel(nn.Module):
    def __init__(self, device, beta_1, beta_T, T, decoder=None):
        '''
        The epsilon predictor of the diffusion process.
        beta_1 : beta_1 of the diffusion process
        beta_T : beta_T of the diffusion process
        T : diffusion steps
        input_dim : the dimension of the data
        '''
        super().__init__()
        self.device = device
        self.alpha_bars = torch.cumprod(1 - torch.linspace(start=beta_1, end=beta_T, steps=T), dim=0).to(device=self.device)
        self.backbone = Transformer_Denoiser(T, batch_first=True).to(device=self.device)
        self.to(device=self.device)

    def loss_fn(self, x, idx=None):
        '''
        This function is used only in the training phase.
        x : real data if idx == None, else perturbed data
        idx : if None (training phase), a random timestep index is perturbed. Else (inference phase), it is recommended that you specify it.
        '''
        output, epsilon, alpha_bar = self.forward(x, idx=idx, get_target=True)
        loss = (output - epsilon).square().mean()
        print(loss.shape)
        return loss

    def forward(self, x, idx=None, get_target=False):
        # print(x.shape)
        if idx == None:
            ### Generate a random timestep index for noise sampling
            # idx_device = self.alpha_bars.get_device()
            idx = torch.randint(0, len(self.alpha_bars), (x.size(0),)).to(self.device)
            print(idx.get_device())
            used_alpha_bars = self.alpha_bars[idx][:, None, None]
            print(self.alpha_bars.get_device())
            print(used_alpha_bars.get_device())
            epsilon = torch.randn_like(x).to(self.device)
            # TODO: lock noise with respect to turns
            x_tilde = torch.sqrt(used_alpha_bars) * x + torch.sqrt(1 - used_alpha_bars) * epsilon
        else:
            idx = torch.Tensor([idx for _ in range(x.size(0))]).long()
            x_tilde = x
        # print(x_tilde.shape)
        # return idx
        output = self.backbone(x_tilde, idx)
        return (output, epsilon, used_alpha_bars) if get_target else output
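For context, the LightningModule that wraps the model looks roughly like this (again a simplified sketch, not the literal train.py; DiffusionTask and its constructor arguments are placeholder names, but validation_step does what the traceback shows at train.py line 91):

class DiffusionTask(pl.LightningModule):
    def __init__(self, device, beta_1, beta_T, T):
        super().__init__()
        # the DiffusionModel defined above, built on the device I pass in explicitly
        self.model = DiffusionModel(device, beta_1, beta_T, T)

    def validation_step(self, batch, batch_idx):
        loss = self.model.loss_fn(batch['dst_matrix'])
        return loss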
I am pretty sure all the tensors are on the same device: in the code I print out the devices of the alpha_bars tensor and the index tensor, and as shown in the output above they all report device 0.
I'd really appreciate it if someone could help.
Thank you very much.