Training time gradually increases per epoch

I’m training an EfficientNetV2 with the following training script:

for epoch in range(Config.num_epochs):
    print(f"{'-'*20} EPOCH: {epoch+1}/{Config.num_epochs} {'-'*20}")
    
    model.train()

    train_prog_bar = tqdm(train_loader, total=len(train_loader))

    train_labels = []
    train_preds = []
    
    running_loss = 0.0
    train_acc = 0.0
    
    for i, (x_train, y_train) in enumerate(train_prog_bar, 0):
        x_train = x_train.to(device).float()
        y_train = y_train.to(device).long()

        optim.zero_grad()

        pred = model(x_train.unsqueeze(1))
        train_loss = train_loss_fn(pred, y_train)

        train_loss.backward()
        optim.step()
                
        scheduler.step()

        running_loss += train_loss.item()

        pred = torch.argmax(pred, 1).detach().cpu().numpy()
        label = y_train.detach().cpu().numpy()
        
        train_acc += (pred == label).sum() / Config.batch_size

        train_labels += [label]
        train_preds += [pred]
        
        train_prog_bar.set_description(f'loss: {train_loss.item():.4f}')
        
    running_loss = running_loss / len(train_loader)
    train_losses.append(running_loss)
    
    print(f"Final Training Loss: {running_loss:.4f}")
    print(f"Final Training Accuracy: {train_acc/len(train_loader)}")
    
    running_loss = 0.0
    train_acc = 0.0
    
    if epoch % 10 == 9:
        print(f"Saving Model {int((epoch+1)/10)} out of {int(Config.num_epochs/10)}")
        torch.save(model.state_dict(), f'../models/efficientnetv2_s_0{int(((epoch+1)/10)-1)}.pt')
            
    train_labels = np.concatenate(train_labels)
    train_preds = np.concatenate(train_preds)
            
    valid_prog_bar = tqdm(valid_loader, total=len(valid_loader))
    
    valid_labels = []
    valid_preds = []
    
    valid_acc = 0.0
    
    model.eval()
        
    with torch.no_grad():
        for x_valid, y_valid in valid_prog_bar:
            x_valid = x_valid.to(device).float()
            y_valid = y_valid.to(device).long()
                    
            valid_pred = model(x_valid.unsqueeze(1))
            valid_loss = valid_loss_fn(valid_pred, y_valid)
                
            running_loss += valid_loss.item()
                
            valid_pred = torch.argmax(valid_pred, 1).detach().cpu().numpy()
            valid_label = y_valid.detach().cpu().numpy()
            
            valid_acc += (valid_pred == valid_label).sum() / Config.batch_size
                
            valid_labels += [valid_label]
            valid_preds += [valid_pred]
    
            valid_prog_bar.set_description(desc=f"loss: {valid_loss.item():.4f}")
            
        running_loss = running_loss / len(valid_loader)
        valid_losses.append(running_loss)
        
        print(f"Final Validation Loss: {running_loss:.4f}")
        print(f"Final Validation Accuracy: {valid_acc/len(valid_loader)}")
        
        running_loss = 0.0
        valid_acc = 0.0
            
        valid_labels = np.concatenate(valid_labels)
        valid_preds = np.concatenate(valid_preds) 
        
    gc.collect()
    torch.cuda.empty_cache()

I was having this problem before and tried adding torch.cuda.empty_cache(), which unfortunately did not help. I have no idea what could be causing this and have never experienced this issue before in PyTorch. The training time increases from 22 minutes, to 22.5 minutes, to 23 minutes, to 24 minutes, to 27.5 minutes, to 35 minutes, to 47 minutes, etc. Since I’m a beginner with PyTorch, please share exact code samples, not theoretical concepts. I have provided the whole notebook below for further debugging, but sadly I can’t share the data. Thanks in advance.

import gc
import math

import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import _LRScheduler

import timm

import cv2 as cv
import albumentations as A

import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut

from tqdm.auto import tqdm

def read_img(path, voi_lut=True, fix_monochrome=True):
    ds = pydicom.read_file(path)
    
    # Apply the VOI LUT (windowing) stored in the DICOM metadata, if requested
    if voi_lut:
        img = apply_voi_lut(ds.pixel_array, ds)
    else:
        img = ds.pixel_array
        
    # MONOCHROME1 images are stored inverted (high values are dark), so flip them
    if fix_monochrome and ds.PhotometricInterpretation == "MONOCHROME1":
        img = np.amax(img) - img
        
    # Normalize to [0, 255] and convert to uint8
    img = img - np.min(img)
    img = img / np.max(img)
    img = (img * 255).astype(np.uint8)
    
    return img

def resize_img(img, size=512, pad=False, interp=cv.INTER_LANCZOS4):
    if pad:
        max_width = 4891
        max_height = 4891
        
        img = np.pad(img, ((0, max_height - img.shape[0]), (0, max_width - img.shape[1])))

    img = cv.resize(img, dsize=(size, size), interpolation=interp)
    
    return img

def augment_img(img, clahe=True, albumentations=True):
    if clahe:
        clahe = cv.createCLAHE(clipLimit=15.0, tileGridSize=(8,8))
        img = clahe.apply(img)
    else:
        img = cv.equalizeHist(img)
        
    if albumentations:
        transform = A.Compose([
            A.HorizontalFlip(p=0.5),
            A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=20, p=1.0),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=1.0),
            A.OneOf([
                A.GaussianBlur(blur_limit=(11, 27)),
                A.GaussNoise(var_limit=(2000, 6000))
            ], p=0.3),
            A.CoarseDropout(max_holes=8, max_height=64, max_width=64, min_holes=4, min_height=32, min_width=32, p=0.3)
        ])
        
        img = transform(image=img)["image"]
        
    return img

class Config:
    model_name = 'efficientnetv2_m'
    img_size = 512
    batch_size = 8
    num_epochs = 50

class EffNet(nn.Module):
    def __init__(self, num_classes=4, model_name=Config.model_name, img_channels=1, img_size=Config.img_size):
        super(EffNet, self).__init__()
        self.model = timm.create_model(model_name, num_classes=num_classes, in_chans=img_channels)
        
    def forward(self, x):
        x = self.model(x)
        return x

class SIIMData(Dataset):
    def __init__(self, data, is_train=True, img_size=Config.img_size):
        self.data = data.sample(frac=1).reset_index(drop=True)
        self.is_train = is_train
        self.img_size = img_size
        
    def __getitem__(self, idx):
        img_id = self.data['id'].values[idx]
        img_path = self.data['path'].values[idx]
        
        img = read_img(img_path)
        img = resize_img(img, self.img_size)
        
        if self.is_train and self.data['augment'].values[idx]:
            img = augment_img(img)
        else:
            img = augment_img(img, albumentations=False)
            
        if self.is_train:
            label = self.data[self.data['id'] == img_id].values.tolist()[0][4:8]
        else:
            label = self.data[self.data['id'] == img_id].values.tolist()[0][3:7]
            
        label = np.array(label)
        label = np.where(label==1)[0][0]
        
        return torch.tensor(img), torch.tensor(label)
    
    def __len__(self):
        return len(self.data)

class CosineAnnealingWarmRestarts(_LRScheduler):
    def __init__(self, optimizer, T_max, T_mult=1, eta_min=0, last_epoch=-1):
        self.T_max = T_max
        self.T_mult = T_mult
        self.Te = self.T_max
        self.eta_min = eta_min
        self.current_epoch = last_epoch
        
        super(CosineAnnealingWarmRestarts, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        new_lrs = [self.eta_min + (base_lr - self.eta_min) * 
                   (1 + math.cos(math.pi * self.current_epoch / self.Te)) / 2 
                   for base_lr in self.base_lrs]
        
        return new_lrs
    
    def step(self, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
            
        self.last_epoch = epoch
        self.current_epoch += 1
        
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr
        
        if self.current_epoch == self.Te:
            self.current_epoch = 0
            
            self.Te = int(self.Te * self.T_mult)
            self.T_max = self.T_max + self.Te

if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
        
train_df = pd.read_csv('../input/siim-covid19-data/train_data.csv')
valid_df = pd.read_csv('../input/siim-covid19-data/valid_data.csv')

train_data = SIIMData(data=train_df)
valid_data = SIIMData(data=valid_df, is_train=False)

train_loader = DataLoader(
    train_data,
    batch_size=Config.batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

valid_loader = DataLoader(
    valid_data,
    batch_size=Config.batch_size,
    shuffle=False
)

model = EffNet().to(device)

optim = torch.optim.SGD(model.parameters(), lr=1.0, momentum=0.9)
scheduler = CosineAnnealingWarmRestarts(optim, T_max=10*len(train_loader), eta_min=0.0001)

train_loss_fn = nn.CrossEntropyLoss()
valid_loss_fn = nn.CrossEntropyLoss()

train_losses = []
valid_losses = []

for epoch in range(Config.num_epochs):
    print(f"{'-'*20} EPOCH: {epoch+1}/{Config.num_epochs} {'-'*20}")
    
    model.train()

    train_prog_bar = tqdm(train_loader, total=len(train_loader))

    train_labels = []
    train_preds = []
    
    running_loss = 0.0
    train_acc = 0.0
    
    for i, (x_train, y_train) in enumerate(train_prog_bar, 0):
        x_train = x_train.to(device).float()
        y_train = y_train.to(device).long()

        optim.zero_grad()

        pred = model(x_train.unsqueeze(1))
        train_loss = train_loss_fn(pred, y_train)

        train_loss.backward()
        optim.step()
                
        scheduler.step()

        running_loss += train_loss.item()

        pred = torch.argmax(pred, 1).detach().cpu().numpy()
        label = y_train.detach().cpu().numpy()
        
        train_acc += (pred == label).sum() / Config.batch_size

        train_labels += [label]
        train_preds += [pred]
        
        train_prog_bar.set_description(f'loss: {train_loss.item():.4f}')
        
    running_loss = running_loss / len(train_loader)
    train_losses.append(running_loss)
    
    print(f"Final Training Loss: {running_loss:.4f}")
    print(f"Final Training Accuracy: {train_acc/len(train_loader)}")
    
    running_loss = 0.0
    train_acc = 0.0
    
    if epoch % 10 == 9:
        print(f"Saving Model {int((epoch+1)/10)} out of {int(Config.num_epochs/10)}")
        torch.save(model.state_dict(), f'../models/efficientnetv2_s_0{int(((epoch+1)/10)-1)}.pt')
            
    train_labels = np.concatenate(train_labels)
    train_preds = np.concatenate(train_preds)
            
    valid_prog_bar = tqdm(valid_loader, total=len(valid_loader))
    
    valid_labels = []
    valid_preds = []
    
    valid_acc = 0.0
    
    model.eval()
        
    with torch.no_grad():
        for x_valid, y_valid in valid_prog_bar:
            x_valid = x_valid.to(device).float()
            y_valid = y_valid.to(device).long()
                    
            valid_pred = model(x_valid.unsqueeze(1))
            valid_loss = valid_loss_fn(valid_pred, y_valid)
                
            running_loss += valid_loss.item()
                
            valid_pred = torch.argmax(valid_pred, 1).detach().cpu().numpy()
            valid_label = y_valid.detach().cpu().numpy()
            
            valid_acc += (valid_pred == valid_label).sum() / Config.batch_size
                
            valid_labels += [valid_label]
            valid_preds += [valid_pred]
    
            valid_prog_bar.set_description(desc=f"loss: {valid_loss.item():.4f}")
            
        running_loss = running_loss / len(valid_loader)
        valid_losses.append(running_loss)
        
        print(f"Final Validation Loss: {running_loss:.4f}")
        print(f"Final Validation Accuracy: {valid_acc/len(valid_loader)}")
        
        running_loss = 0.0
        valid_acc = 0.0
            
        valid_labels = np.concatenate(valid_labels)
        valid_preds = np.concatenate(valid_preds) 
        
    gc.collect()
    torch.cuda.empty_cache()

Since I can’t see any obvious errors, I would try to reduce the code until the slowdown is no longer observed in order to isolate the issue, i.e. remove all metric calculation, loss printing, etc. and check the actual training loop alone.
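
To illustrate, a stripped-down loop could look roughly like the sketch below. This is only a sketch that reuses the objects already defined in your notebook (model, optim, scheduler, train_loader, train_loss_fn, device, Config); it drops all metric tracking, list accumulation, checkpointing and validation, and only prints the wall time per epoch:

import time

for epoch in range(Config.num_epochs):
    model.train()
    epoch_start = time.time()

    for x_train, y_train in train_loader:
        x_train = x_train.to(device).float()
        y_train = y_train.to(device).long()

        optim.zero_grad()
        pred = model(x_train.unsqueeze(1))
        loss = train_loss_fn(pred, y_train)
        loss.backward()
        optim.step()
        scheduler.step()

    # wait for all queued CUDA kernels before reading the clock
    torch.cuda.synchronize()
    print(f"epoch {epoch+1}: {time.time() - epoch_start:.1f}s")

If this bare loop still slows down from epoch to epoch, the issue sits in the model/optimizer/scheduler path; if it doesn’t, add the removed pieces (metrics, list accumulation, validation, checkpointing) back one at a time until the slowdown reappears.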

torch.cuda.empty_cache() will not avoid out-of-memory issues; it only releases the unused device memory held in PyTorch's caching allocator, which will slow down your code since that memory then has to be reallocated again.
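
If it helps to see the effect, here is a minimal, standalone illustration (not tied to your notebook) of what the cache does, using the allocated vs. reserved memory counters:

import torch

x = torch.randn(1024, 1024, device='cuda')
del x
# the tensor's memory is no longer allocated, but stays reserved in the cache
print(torch.cuda.memory_allocated(), torch.cuda.memory_reserved())

torch.cuda.empty_cache()
# the cached (reserved) memory is now released back to the driver,
# so the next allocation has to request it from CUDA again
print(torch.cuda.memory_allocated(), torch.cuda.memory_reserved())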

Sure, let me try doing that and get back to you.

The slowdown was still visible after stripping down the code.

Thanks for the update! I assume you’ve created a code snippet containing the model training loop alone and are still seeing the slowdown in each epoch?
If so, could you please post the code here so that we can take a look?