Hi I am facing issue in training my model in DDP setting. I have a warmup phase and then training phase. The warmup phase runs smoothly without any issue, while in the training phase, I am training and validating at each epoch. The first epoch of train and validation runs smoothly, while when the second epoch begins, it gets stuck at the last iteration of second epoch. This is my code
def warmup_one_epoch(self, dataloader:DataLoader, lr_scheduler:optim.lr_scheduler.LRScheduler) -> None:
self.model.train()
self.optimizer.zero_grad(set_to_none=True)
metric = {}
running_losses = {}
scaler = amp.grad_scaler.GradScaler(self.__ddp_config['device'].type)
if is_main_process():
progress = tqdm(dataloader, desc='Warming_up...', leave=False, position=1, colour='#0b83de', dynamic_ncols=True)
else:
progress = dataloader
with self.model.join():
for epoch, batch in enumerate(progress):
image = batch[0].to(self.__ddp_config['device'], non_blocking=True)
target = move_to_device(batch[1], self.__ddp_config['device'], non_blocking=True)
if ((epoch + 1) % self.kwargs['num_updates']) != 0 and hasattr(self.model, 'no_sync'):
sync_context = self.model.no_sync()
else:
sync_context = contextlib.nullcontext()
with sync_context:
if self.kwargs['mixed_precision']:
with amp.autocast_mode.autocast(device_type=self.__ddp_config['device'].type, dtype=self.kwargs['autocast_dtype']):
pred = self.model(image)
loss, _ = self.criterion(pred, target)
total_loss = (loss['cls_loss'] + loss['mask_loss']) / self.kwargs['num_updates']
scaler.scale(total_loss).backward()
else:
with amp.autocast_mode.autocast(device_type=self.__ddp_config['device'].type, dtype=self.kwargs['autocast_dtype']):
pred = self.model(image)
loss, _ = self.criterion(pred, target)
total_loss = (loss['cls_loss'] + loss['mask_loss']) / self.kwargs['num_updates']
total_loss.backward()
if ((epoch+1)%self.kwargs['num_updates'] == 0) or (epoch+1 == len(dataloader)):
if self.kwargs['mixed_precision']:
scaler.unscale_(self.optimizer)
# nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.)
scaler.step(self.optimizer)
scaler.update()
else:
# nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.)
self.optimizer.step()
self.optimizer.zero_grad(set_to_none=True)
for key, value in loss.items():
if not key in running_losses.keys():
running_losses[key] = 0.
running_losses[key] += self.reduce_metrics(value.detach() / self.kwargs['num_updates']).item() / len(dataloader)
if is_main_process():
current_lr = self.optimizer.param_groups[0]['lr']
progress.set_postfix({
'lr': f"{current_lr:.3f}",
**{k: f"{v:.3f}" for k, v in running_losses.items()}
})
if is_main_process():
progress.close()
running_losses['total_loss'] = running_losses['cls_loss'] + running_losses['mask_loss']
lr_scheduler.step()
metric["warmup_loss"] = running_losses
torch.cuda.synchronize()
gc.collect()
torch.cuda.empty_cache()
torch.clear_autocast_cache()
return metric
def train_one_epoch(self, dataloader:DataLoader) -> dict:
self.model.train()
self.optimizer.zero_grad(set_to_none=True)
metric = {}
running_losses = {}
running_metrics = {}
scaler = amp.grad_scaler.GradScaler(self.__ddp_config['device'].type)
if is_main_process():
progress = tqdm(dataloader, desc='Training...', leave=False, position=1, colour='#0b83de', dynamic_ncols=True)
else:
progress = dataloader
with self.model.join():
for epoch, batch in enumerate(progress):
image = batch[0].to(self.__ddp_config['device'], non_blocking=True)
target = move_to_device(batch[1], self.__ddp_config['device'], non_blocking=True)
if ((epoch + 1) % self.kwargs['num_updates']) != 0 and hasattr(self.model, 'no_sync'):
sync_context = self.model.no_sync()
else:
sync_context = contextlib.nullcontext()
with sync_context:
if self.kwargs['mixed_precision']:
with amp.autocast_mode.autocast(device_type=self.__ddp_config['device'].type, dtype=self.kwargs['autocast_dtype']):
pred = self.model(image)
loss, idx = self.criterion(pred, target)
total_loss = (loss['cls_loss'] + loss['mask_loss']) / self.kwargs['num_updates']
scaler.scale(total_loss).backward()
else:
with amp.autocast_mode.autocast(device_type=self.__ddp_config['device'].type, dtype=self.kwargs['autocast_dtype']):
pred = self.model(image)
loss, idx = self.criterion(pred, target)
total_loss = (loss['cls_loss'] + loss['mask_loss']) / self.kwargs['num_updates']
total_loss.backward()
if ((epoch+1)%self.kwargs['num_updates'] == 0) or (epoch+1 == len(dataloader)):
if self.kwargs['mixed_precision']:
scaler.unscale_(self.optimizer)
scaler.step(self.optimizer)
scaler.update()
else:
self.optimizer.step()
self.optimizer.zero_grad(set_to_none=True)
for key, value in loss.items():
if not key in running_losses.keys():
running_losses[key] = 0.
running_losses[key] += self.reduce_metrics(value / self.kwargs['num_updates']).item() / len(dataloader)
# cls_metric, instance_metric, semantic_metric = self.metric(pred, target, idx)
# for (cls_key, cls_val), (sem_key, sem_val) in zip(cls_metric.items(), semantic_metric.items()):
# cls_metric[cls_key] = self.reduce_metrics(cls_val.detach()).item()
# semantic_metric[sem_key] = self.reduce_metrics(sem_val.detach()).item()
# running_metrics['cls_metrics'] = cls_metric
# running_metrics['instance_metrics'] = instance_metric
# running_metrics['semantic_metrics'] = semantic_metric
# running_metrics = add_metrics(running_metrics, running_metrics, len(dataloader))
if is_main_process():
current_lr = self.optimizer.param_groups[0]['lr']
progress.set_postfix({
'lr': f"{current_lr:.3f}",
**{k: f"{v:.3f}" for k, v in running_losses.items()}
})
if is_main_process():
progress.close()
running_losses['total_loss'] = running_losses['cls_loss'] + running_losses['mask_loss']
metric['train_loss'] = running_losses
metric['train_metrics'] = running_metrics
torch.cuda.synchronize()
gc.collect()
torch.cuda.empty_cache()
torch.clear_autocast_cache()
return metric
def validate_one_epoch(self, dataloader:DataLoader) -> dict:
self.model.eval()
metric = {}
running_losses = {}
running_metrics = {}
if is_main_process():
progress = tqdm(dataloader, desc='Validating...', leave=False, position=1, colour='#0b83de', dynamic_ncols=True)
else:
progress = dataloader
with torch.no_grad():
with self.model.join():
for _, batch in enumerate(progress):
image = batch[0].to(self.__ddp_config['device'], non_blocking=True)
target = move_to_device(batch[1], self.__ddp_config['device'], non_blocking=True)
if self.kwargs['mixed_precision']:
with amp.autocast_mode.autocast(device_type=self.__ddp_config['device'].type, dtype=self.kwargs['autocast_dtype']):
pred = self.model(image)
loss, idx = self.criterion(pred, target)
else:
with amp.autocast_mode.autocast(device_type=self.__ddp_config['device'].type, dtype=self.kwargs['autocast_dtype']):
pred = self.model(image)
loss, idx = self.criterion(pred, target)
for key, value in loss.items():
if not key in running_losses.keys():
running_losses[key] = 0.
running_losses[key] += self.reduce_metrics(value.detach() / self.kwargs['num_updates']).item() / len(dataloader)
# cls_metric, instance_metric, semantic_metric = self.metric(pred, target, idx)
# for (cls_key, cls_val), (sem_key, sem_val) in zip(cls_metric.items(), semantic_metric.items()):
# cls_metric[cls_key] = self.reduce_metrics(cls_val.detach()).item()
# semantic_metric[sem_key] = self.reduce_metrics(sem_val.detach()).item()
# running_metrics['cls_metrics'] = cls_metric
# running_metrics['instance_metrics'] = instance_metric
# running_metrics['semantic_metrics'] = semantic_metric
# running_metrics = add_metrics(running_metrics, running_metrics, len(dataloader))
if is_main_process():
current_lr = self.lr_scheduler.get_last_lr()[0]
progress.set_postfix({
'lr': f"{current_lr:.3f}",
**{k: f"{v:.3f}" for k, v in running_losses.items()}
})
if is_main_process():
progress.close()
running_losses['total_loss'] = running_losses['cls_loss'] + running_losses['mask_loss']
self.lr_scheduler.step(running_losses['total_loss'])
metric['val_loss'] = running_losses
metric['val_metrics'] = running_metrics
torch.cuda.synchronize()
gc.collect()
torch.clear_autocast_cache()
torch.cuda.empty_cache()
return metric
def __call__(self, dataloader:Union[DataLoader, Tuple[DataLoader, DataLoader]]) -> None:
train_loader = val_loader = None
if len(dataloader) == 2:
train_loader, val_loader = dataloader
elif len(dataloader) == 1:
train_loader = dataloader
else:
raise RuntimeError(f"length of dataloader is {len(dataloader)}, supports only 1 or 2")
if self.logger is not None and is_main_process():
self.logger.log_graph(torch.randn(1, 3, self.image_size, self.image_size,
device=next(self.model.parameters()).device))
metric = {}
progress = None
best_val_loss = float("inf")
if self.kwargs['warmup_epochs'] > 0 and not os.path.isfile(f"{self.kwargs['snapshot_dir']}/{self.model.module.variant}.pt"):
warmup_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda epoch: self.warmup_lr_schedule(epoch=epoch, warmup_epochs=self.kwargs['warmup_epochs'], initial_lr=self.kwargs['lr']/(dist.get_world_size()*self.kwargs['warmup_epochs']/2), final_lr=self.kwargs['lr']/dist.get_world_size()))
if is_main_process():
progress = tqdm(range(self.kwargs['warmup_epochs']), desc='Warmup Phase...', leave=True, position=0, colour='#1066a3', dynamic_ncols=True)
else:
progress = range(self.kwargs['warmup_epochs'])
for epoch in progress:
train_loader.sampler.set_epoch(epoch)
warmup_metrics = self.warmup_one_epoch(train_loader, self.optimizer, warmup_scheduler)
if is_main_process():
current_lr = warmup_scheduler.get_last_lr()[0]
progress.set_postfix({
'lr': f"{current_lr:.3f}",
**{k: f"{v:.3f}" for k, v in warmup_metrics['warmup_loss'].items()}
})
dist.barrier()
if is_main_process():
self._save_snapshot(epoch=0)
progress.close()
dist.barrier()
if is_main_process():
print("\n")
progress = tqdm(range(self.epoch, self.kwargs['epochs']), desc='Train Validation Phase...', leave=True, position=0, colour='#1066a3')
else:
progress = range(self.epoch, self.kwargs['epochs'])
for epoch in progress:
train_loader.sampler.set_epoch(epoch)
train_metrics = self.train_one_epoch(train_loader)
if val_loader is not None:
val_loader.sampler.set_epoch(epoch)
val_metrics = self.validate_one_epoch(val_loader)
metric['learning_rate'] = self.optimizer.param_groups[0]['lr'] * dist.get_world_size()
metric['train_metric'] = train_metrics
metric['val_metric'] = val_metrics if val_loader is not None else None
if is_main_process():
current_lr = self.optimizer.param_groups[0]['lr']
progress.set_postfix({
'lr': f"{current_lr:.3f}",
'train_loss': f"{train_metrics['train_loss']['total_loss']:.4f}",
'val_loss': f"{val_metrics['val_loss']['total_loss']:.4f}"
})
if (epoch+1) % self.kwargs['snapshot_log_interval'] == 0:
self._save_snapshot(epoch+1)
if val_metrics['val_loss']['total_loss'] < best_val_loss:
best_val_loss = val_metrics['val_loss']['total_loss']
self._check_dir(self.kwargs['checkpoint_dir'])
save_model(self.model.module, f"RTSeg_{self.model.module.variant}_batch_{self.kwargs['batch_size']}_best", self.kwargs['checkpoint_dir'])
self._check_dir(self.kwargs['log_dir'])
self.log_metrics(metric, epoch+1)
self.epoch = epoch + 1
dist.barrier()
I am using torch==2.5.1 with 3 Ada 6000 GPUs