I am facing an issue where my training loop suddenly freezes after a certain epoch when my dataset is large. I am training a model using the BYOL strategy. In a test run with a small dataset (6 datapoints), the training loop freezes after the 6th epoch and then resumes after some time; with the larger dataset, however, it freezes after the 6th epoch and system memory consumption keeps increasing until the process is killed.
def train_one_epoch(self, train_loader: DataLoader, optimizer: optim.Optimizer, **kwargs) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Train the online model for a single epoch over ``train_loader``.

    Args:
        train_loader (DataLoader): training dataloader yielding (view1, view2) augmented pairs
        optimizer (optim.Optimizer): optimizer for the online model
        **kwargs: ``mixed_precision`` (bool, optional) enables autocast + GradScaler

    Returns:
        Tuple[float, Tuple[torch.Tensor, torch.Tensor]]: mean training loss over the epoch,
        (augmented view 1, augmented view 2) from the last batch for tensorboard logging
    """
    running_training_loss = 0.0
    # NOTE(review): a fresh GradScaler per epoch discards its calibrated loss scale;
    # consider constructing it once in train() and passing it in.
    scaler = amp.grad_scaler.GradScaler(device.type)
    self.online_model.train()
    # BUG FIX: progress-bar description said 'Validating...' inside the training loop.
    pbar = tqdm(train_loader, desc='Training...', total=len(train_loader), unit='batch', leave=False, position=1)
    # BUG FIX: kwargs['mixed_precision'] raised KeyError when the flag was omitted.
    use_amp = kwargs.get('mixed_precision', False)
    for view1, view2 in pbar:
        view1, view2 = view1.to(device), view2.to(device)
        optimizer.zero_grad(set_to_none=True)
        if use_amp:
            with amp.autocast_mode.autocast(device.type):
                o1 = self.online_model(view1)
                o2 = self.online_model(view2)
                # Target network provides fixed regression targets — no gradients.
                with torch.no_grad():
                    t1 = self.target_model(view1)
                    t2 = self.target_model(view2)
                # Symmetric BYOL loss: each online view regresses the other target view.
                loss = (self.calculate_loss(o1, t2.detach()) + self.calculate_loss(o2, t1.detach())).mean()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            o1 = self.online_model(view1)
            o2 = self.online_model(view2)
            with torch.no_grad():
                t1 = self.target_model(view1)
                t2 = self.target_model(view2)
            loss = (self.calculate_loss(o1, t2.detach()) + self.calculate_loss(o2, t1.detach())).mean()
            loss.backward()
            optimizer.step()
        running_training_loss += loss.item() / len(train_loader)
        pbar.set_postfix({'Training loss': running_training_loss})
        # BUG FIX: removed the manual pbar.update() — iterating `pbar` already
        # advances it, so the extra call double-counted progress.
    # NOTE(review): empty_cache() forces a device sync each epoch and does not reduce
    # peak usage. The epoch-boundary freeze is characteristic of DataLoader workers
    # being torn down and re-spawned every epoch — try persistent_workers=True on the
    # DataLoader, and check num_workers/prefetch_factor for host-RAM growth.
    torch.cuda.empty_cache()
    return running_training_loss, (view1[0], view2[0])
def validate(self, validation_loader: DataLoader, lr_scheduler: optim.lr_scheduler.LRScheduler) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Run a single validation loop for the model.

    Args:
        validation_loader (DataLoader): validation loader yielding (view1, view2) pairs
        lr_scheduler (LRScheduler): plateau scheduler, stepped once with the epoch loss

    Returns:
        Tuple[float, Tuple[torch.Tensor, torch.Tensor]]: mean validation loss,
        (augmented view 1, augmented view 2) from the last batch for tensorboard logging
    """
    running_validation_loss = 0.0
    self.online_model.eval()
    self.target_model.eval()
    pbar = tqdm(validation_loader, desc='Validating...', total=len(validation_loader), unit='batch', leave=False, position=1)
    with torch.no_grad():
        for view1, view2 in pbar:
            view1, view2 = view1.to(device), view2.to(device)
            o1 = self.online_model(view1)
            o2 = self.online_model(view2)
            t1 = self.target_model(view1)
            t2 = self.target_model(view2)
            # Symmetric BYOL loss (no .detach() needed — whole loop is under no_grad).
            loss = (self.calculate_loss(o1, t2) + self.calculate_loss(o2, t1)).mean()
            running_validation_loss += loss.item() / len(validation_loader)
            pbar.set_postfix({'Validation loss': running_validation_loss})
            # BUG FIX: removed the manual pbar.update() — iterating `pbar` already
            # advances it, so the extra call double-counted progress.
    # BUG FIX: ReduceLROnPlateau was stepped once per *batch* with the raw batch loss,
    # so `patience` counted batches instead of epochs — on a large dataset the LR
    # collapsed to min_lr almost immediately. Step once per epoch with the mean loss.
    lr_scheduler.step(running_validation_loss)
    torch.cuda.empty_cache()
    return running_validation_loss, (view1[0], view2[0])
def train(self,
          dataloader: Union[DataLoader, Tuple[DataLoader, DataLoader]],
          epochs: int = 100,
          lr_rate: float = 1e-4,
          optimizer: optim.Optimizer = optim.Adam,
          lr_scheduler: optim.lr_scheduler.LRScheduler = optim.lr_scheduler.ReduceLROnPlateau,
          logger: Logger = None,
          save_dir: str = "weights/SegFormer/encoder/",
          **kwargs) -> nn.Module:
    r"""
    Implements the BYOL training strategy.

    Args:
        dataloader (Union[DataLoader, Tuple[DataLoader, DataLoader]]): training loader,
            or a (training, validation) pair
        epochs (int): number of epochs
        lr_rate (float): learning rate
        optimizer (optim.Optimizer): optimizer class (instantiated here)
        lr_scheduler (optim.lr_scheduler.LRScheduler): scheduler class (instantiated here)
        logger (Logger): tensorboard logger, optional
        save_dir (str): model save directory
        **kwargs: must contain ``weight_decay``; forwarded to train_one_epoch

    Returns:
        nn.Module: trained encoder
    """
    train_loader, validation_loader = None, None
    if isinstance(dataloader, (list, tuple)):
        train_loader, validation_loader = dataloader
    else:
        train_loader = dataloader
    if logger is not None:
        logger.log_graph(input=torch.randn(1, 3, self.target_model.image_size, self.target_model.image_size).to(device))
    optimizer = optimizer(self.online_model.parameters(), lr_rate, eps=1e-4, weight_decay=kwargs['weight_decay'])
    lr_scheduler = lr_scheduler(optimizer, mode="min", threshold=0.01, patience=3, min_lr=1e-5, eps=1e-3)
    # BUG FIX: the profiler was taken from `logger` unconditionally, crashing with
    # AttributeError when logger is None despite the guards everywhere else.
    profiler = logger.gpu_profiler() if logger is not None else None
    if profiler is not None:
        profiler.start()
    pbar = tqdm(range(epochs), desc="Training...", unit="epoch", bar_format="{l_bar}{bar}{r_bar}", dynamic_ncols=True, colour="#1f5e2a", position=0, leave=True)
    best_validation_loss = 2.
    # BUG FIX: validation_loss was read below even when no validation loader was
    # supplied, raising NameError; initialize and guard its consumers instead.
    validation_loss = None
    for epoch in pbar:
        training_loss, (view1, view2) = self.train_one_epoch(train_loader, optimizer, **kwargs)
        if profiler is not None:
            profiler.step()
        pbar.set_postfix({"Training Loss": training_loss})
        if validation_loader is not None:
            validation_loss, (view1, view2) = self.validate(validation_loader, lr_scheduler)
            pbar.set_postfix({"Training Loss": training_loss,
                              "Validation Loss": validation_loss,
                              "Learning Rate": lr_scheduler.get_last_lr()[-1]})
        if logger is not None:
            logger.log_scaler(epoch, "Learning Rate", lr_scheduler.get_last_lr()[-1])
            logger.log_scaler(epoch, "Metrics/Training Loss", training_loss)
            if validation_loss is not None:
                logger.log_scaler(epoch, "Metrics/Validation Loss", validation_loss)
            # NOTE(review): indices -4 / -2 assume a fixed parameter ordering — confirm
            # these really select the embedding layers of each model.
            logger.log_histogram(epoch, "Weights Distribution/Online Model Embedding Layer", list(self.online_model.named_parameters())[-4][1])
            logger.log_histogram(epoch, "Weights Distribution/Target Model Embedding Layer", list(self.target_model.named_parameters())[-2][1])
            if epoch % 10 == 0:
                logger.log_images(epoch, "Images/View 1", view1)
                logger.log_images(epoch, "Images/View 2", view2)
                with torch.no_grad():
                    # BUG FIX: tag typo "Model Ouput" split the two models into
                    # different tensorboard groups.
                    logger.log_images(epoch, "Model Output/Online Model", view_attention(self.online_model_encoder.encoder(view1.unsqueeze(0))[0]))
                    logger.log_images(epoch, "Model Output/Target Model", view_attention(self.target_model.encoder(view1.unsqueeze(0))[0]))
        if epoch % 10 == 0 and validation_loss is not None:
            os.makedirs(save_dir, exist_ok=True)
            if validation_loss < best_validation_loss:
                save_model(self.online_model_encoder, f"SegFormer_{self.online_model_encoder.encoder_variant}_{train_loader.batch_size}_{epoch}", save_dir)
                best_validation_loss = validation_loss
        # BUG FIX: removed the manual pbar.update() — iterating `pbar` already
        # advances it, so the extra call double-counted epochs.
    if profiler is not None:
        profiler.stop()
    if logger is not None:
        logger.log_hparams(hparams={"batch_size": train_loader.batch_size, "learning_rate": lr_rate},
                           metrics={"Training Loss": training_loss, "Validation Loss": validation_loss}
                           )
    return self.online_model_encoder