Training is very slow after one epoch

I am trying to train a deep learning model for segmentation using ResNet as the backbone and a Feature Pyramid Network as the decoder. The model is being trained on the Cityscapes dataset. The first epoch completes within 30 minutes, which I guess is normal given that I iterate over the training dataset twice in each epoch. However, after the first epoch, training becomes very slow and never completes, even after 3-4 hours. I suspect the problem lies in data loading. Here's my data loader:
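```python
# worker_init_fn(args.seed) must return a callable that accepts a worker id,
# since DataLoader calls worker_init_fn(worker_id) in each worker process.
sup_loader = data.DataLoader(sup_data, batch_size=args.compute_batch_size,
                             shuffle=False, num_workers=args.num_workers, pin_memory=True,
                             collate_fn=collate_eval, worker_init_fn=worker_init_fn(args.seed))
```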
I am training on Google Colab Pro. What might be the reason behind this anomaly?
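One way to test the data-loading guess is to measure how long the training loop blocks waiting for each batch. Below is a minimal, framework-agnostic sketch; `timed_batches` and `summarize` are hypothetical helper names, not PyTorch APIs. It wraps any iterable (a `DataLoader` included) and records the time each batch took to arrive, so you can compare the first epoch against later ones.

```python
import time
from typing import Iterable, Iterator, List, Tuple, TypeVar

T = TypeVar("T")


def timed_batches(loader: Iterable[T], wait_times: List[float]) -> Iterator[T]:
    """Yield items from `loader`, appending to `wait_times` how long each one took to arrive.

    The first measurement includes iter(loader), which for a DataLoader
    with num_workers > 0 also starts the worker processes.
    """
    start = time.perf_counter()
    it = iter(loader)
    while True:
        try:
            batch = next(it)
        except StopIteration:
            return
        wait_times.append(time.perf_counter() - start)
        yield batch
        # Restart the clock once the caller has finished its step on this batch.
        start = time.perf_counter()


def summarize(wait_times: List[float]) -> Tuple[float, float]:
    """Return (total wait, worst single wait) in seconds."""
    if not wait_times:
        return 0.0, 0.0
    return sum(wait_times), max(wait_times)


if __name__ == "__main__":
    def slow_source(n: int, delay: float) -> Iterator[int]:
        # Stand-in for a slow data loader.
        for i in range(n):
            time.sleep(delay)
            yield i

    waits: List[float] = []
    for batch in timed_batches(slow_source(3, 0.05), waits):
        pass  # a training step would go here
    total, worst = summarize(waits)
    print(f"waited {total:.2f}s on data, worst batch {worst:.2f}s")
```

In a real run you would wrap `sup_loader` in the training loop and clear `waits` at the start of each epoch. If the total wait jumps after epoch 1 while the compute per step stays flat, the loader is the bottleneck. If it doesn't, the slowdown is somewhere else in the loop.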

@Rakibul_Hasan_Rajib Can you share the train function as well as the data loader parameters? The data loader call looks fine, but if the number of workers is too high, you will face this problem.
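Following the point about worker count, one simple safeguard is to never request more loader workers than the machine has CPUs. The CPU count on a Colab VM can be small. The sketch below assumes PyTorch 1.7 or later, which added `persistent_workers`. `loader_worker_kwargs` is a hypothetical helper name.

```python
import os
from typing import Any, Dict, Optional


def loader_worker_kwargs(requested_workers: int,
                         cpu_count: Optional[int] = None) -> Dict[str, Any]:
    """Hypothetical helper: cap num_workers at the visible CPU count.

    Returns keyword arguments intended for torch.utils.data.DataLoader.
    """
    if cpu_count is None:
        cpu_count = os.cpu_count() or 1  # os.cpu_count() may return None
    workers = max(0, min(requested_workers, cpu_count))
    return {
        "num_workers": workers,
        # Reuse worker processes across epochs instead of re-creating them.
        # DataLoader rejects persistent_workers=True when num_workers == 0.
        "persistent_workers": workers > 0,
    }
```

You would pass `**loader_worker_kwargs(args.num_workers)` to `data.DataLoader` in place of the explicit `num_workers=args.num_workers` argument. Passing both raises a duplicate-keyword error.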

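```python
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

# AverageMeter, get_faiss_module and feature_flatten are helpers from the
# project code and are assumed to be importable here.


def run_mini_batch_kmeans(args, sup_loader, model, view, unsup_loader=None):
    kmeans_loss = AverageMeter()           # not used below
    faiss_module = get_faiss_module(args)  # not used below
    data_count = np.zeros(args.num_classes)

    featslist = []   # not used below
    num_batches = 0  # not used below

    sup_loader.dataset.view = view
    # unsup_loader was referenced here without being a parameter (a NameError
    # unless a global of that name exists); it is now an optional argument.
    if unsup_loader is not None:
        unsup_loader.dataset.view = view

    centroids = np.zeros((args.num_classes, args.in_dim), dtype='float32')

    with torch.no_grad():
        for i_batch, (indice, image, label) in tqdm(enumerate(sup_loader)):
            image = image.cuda(non_blocking=True)
            feats = model(image)

            if args.metric_test == 'cosine':
                feats = F.normalize(feats, dim=1, p=2)
            # Upsample features to the input resolution, flatten them to
            # (num_pixels, in_dim) and copy them to host memory.
            feats = F.interpolate(feats, image.shape[-2:], mode='bilinear', align_corners=False)
            feats = feature_flatten(feats).detach().cpu().numpy().astype('float32')
            # Labels are only used on the CPU, so skip the round trip through the GPU.
            label = label.flatten().long().numpy()

            # Accumulate per-class feature sums and pixel counts.
            for k in range(args.num_classes):
                idx_k = np.where(label == k)[0]
                data_count[k] += len(idx_k)
                centroids[k] += np.sum(feats[idx_k], axis=0)

    # Divide the sums by the counts to get one mean feature per class.
    for k in range(args.num_classes):
        centroids[k] /= data_count[k] + 1e-6
    return centroids
```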
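The inner loop above scans every pixel of the batch once per class, and it runs on the CPU after the full-resolution feature map has been copied off the GPU. At Cityscapes' native 2048×1024 resolution and an assumed `in_dim` of 128, one image's upsampled float32 features take 2048·1024·128·4 bytes = 1 GiB. This may not be the whole story behind the slowdown, but the accumulation can be done in one pass over the labels. Below is a minimal NumPy sketch; `class_feature_sums` is a hypothetical name. Like the original loop, it skips labels outside `[0, num_classes)`, such as an ignore index.

```python
import numpy as np


def class_feature_sums(feats: np.ndarray, labels: np.ndarray, num_classes: int):
    """Per-class feature sums and pixel counts without a per-class scan.

    feats:  (num_pixels, in_dim) float array
    labels: (num_pixels,) integer array; values outside [0, num_classes)
            are skipped, as in the per-class np.where loop.
    Returns sums with shape (num_classes, in_dim) and counts with shape (num_classes,).
    """
    labels = np.asarray(labels).ravel()
    valid = (labels >= 0) & (labels < num_classes)
    labels = labels[valid].astype(np.int64)
    feats = feats[valid]

    counts = np.bincount(labels, minlength=num_classes).astype(np.float64)
    sums = np.empty((num_classes, feats.shape[1]), dtype=np.float64)
    # bincount with weights adds feats[:, d] into the bin of each pixel's label.
    for d in range(feats.shape[1]):
        sums[:, d] = np.bincount(labels, weights=feats[:, d], minlength=num_classes)
    return sums, counts
```

Inside the batch loop, the `for k in range(args.num_classes)` block would become `sums, counts = class_feature_sums(feats, label, args.num_classes)`, then `centroids += sums.astype('float32')` and `data_count += counts`. The same idea could also stay on the GPU with `torch.Tensor.index_add_`, which would avoid copying the upsampled features to the CPU at all.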