Loss.backward() time increases for each batch

I created a network for training on images. It uses attention on parts of the image and a transformer encoder on the final part. But the time taken by loss.backward() increases with each batch during training. It stays constant for some initial epochs, but after that it increases rapidly.

import torch
import torch.nn as nn
from icecream import ic
import os


class Flatten(nn.Module):
    """Flattens every dimension except the batch dimension."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        bs = x.shape[0]
        return x.reshape((bs, -1))


class ConvModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.seq = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            Flatten(),
            nn.Linear(57600, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 25),
            nn.Softmax(dim=-1)  # note: redundant if trained with nn.CrossEntropyLoss, which expects raw logits
        )

    def forward(self, x):
        # ic(x.shape)
        u = self.seq(x)
        # ic(u.shape)
        return u


class SeparateLinear(torch.nn.Module):
    def __init__(self, n, ins, out):
        super().__init__()

        self.Linears = {}
        for i in range(n):
            self.Linears[str(i)] = nn.Linear(ins, out)
        self.Linears = nn.ModuleDict(self.Linears)

    def forward(self, x):
        outs = []
        bs = x.shape[0]
        x = x.reshape((16,bs,32,6,6)) 
        for i in range(x.shape[0]):
            a = self.Linears[str(i)](x[i])
            outs.append(a)
        ou = torch.stack(outs)
        ou = ou.reshape(bs,16,32,6,16)
        ##ic(ou.shape)
        return ou


class AttenMod(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, stride=1)
        self.pool1 = nn.MaxPool2d(2, stride=2)
        self.dropout1 = nn.Dropout(0.1)
        self.conv2 = nn.Conv2d(32, 32, 3, stride=1)
        self.pool2 = nn.MaxPool2d(2, stride=2)
        self.dropout2 = nn.Dropout(0.1)
        self.expand = nn.Linear(36,64)
        self.relu1 = nn.ReLU()
        self.dropout3 = nn.Dropout(0.1)
        dicts = {}
        for i in range(16):
            dicts["M-atten-" + str(i)] = nn.MultiheadAttention(64, 4, dropout=0.2)     
        
        self.multiheads = nn.ModuleDict(dicts)
        self.atten_dropout = nn.Dropout(0.2)
        
        self.layernorm1 = nn.LayerNorm(64)
        encod = nn.TransformerEncoderLayer(d_model=2048, nhead=16)
        self.final_encoder = nn.TransformerEncoder(encod, num_layers=1)
        
        self.flin1 = nn.Linear(16*2048,512)
        self.flin2 = nn.Linear(512,128)
        self.flin3 = nn.Linear(128,25)
        self.att_mat = None
        
 

    def forward(self,t):
        
        bs = t.shape[0]

        sliced = torch.zeros((t.shape[0],16, 32, 32))
        # ic(x.shape)
        for i in range(4):
            for j in range(4):
                # print(" {}:{}, {}:{}".format(i * 32, i * 32 + 32, j * 32, j * 32 + 32))
                sliced[:,i * 4 + j] = t[:,i * 32:i * 32 + 32, j * 32:j * 32 + 32]

        x = sliced
        x = x.reshape((bs*16, 1, 32, 32))
        
        u = self.conv1(x)
        u = self.pool1(u)
        u = self.dropout1(u)

        u = self.conv2(u)
        u = self.pool2(u)
        u = self.dropout2(u)
        u = u.reshape((bs*16,32,36))
        u = self.expand(u)         
        u = self.relu1(u)
        # ic(u.shape)
        u = u.reshape((16,32,bs,64))
        attention_out =[]
        
        for i in range(len(self.multiheads)):
            att_out, _att_mat = (self.multiheads["M-atten-" + str(i)](u[i] ,u[i], u[i]))
            attention_out.append(att_out)

        attention_out_stacked = torch.stack(attention_out) 
        u = u + self.atten_dropout(attention_out_stacked)
        u = self.layernorm1(u)
        u = u.reshape((16,bs,32*64))
        u = self.final_encoder(u)         
        u = u.reshape((bs,16*2048))
        u = self.flin1(u)
        u = nn.functional.relu(u)
        u = self.flin2(u)
        u = nn.functional.relu(u)
        u = self.flin3(u)
        u = nn.functional.softmax(u, dim=-1)  # note: nn.CrossEntropyLoss applies log-softmax internally, so this extra softmax shrinks the gradients
        # ic(u.shape)
        return u

I used the following code to train the model.

from attensat import *
import time
import sys
import tqdm
import progressbar
import pickle
import cv2
import numpy as np
import matplotlib.pyplot as plt
import random
print("Importation Finished")

class iterdata(torch.utils.data.IterableDataset):

    def __init__(self, start, end):
        super().__init__()
        assert end > start, "end must be greater than start"
        self.start = start
        self.end = end

        self.videos = pickle.load(open("../saved/2FG-V.pkl","rb"))
        self.labels = pickle.load(open("../saved/2FG-L.pkl","rb"))
        print("Loading Pickle Finished")
        self._preprocess()

        ic(len(self.images))
        ic(len(self.imglabels))

    def _preprocess(self):
        self.images = []
        frames_count = 0
        self.imglabels = []
        lab_dic = ['teacher', 'name', 'word', 'eraser', 'result', 'memorize', 'pen', 'scale', 'paper', 'principal', 'student', 'exam', 'blackboard', 'pass', 'picture', 'education', 'college', 'university', 'pencil', 'title', 'file', 'book', 'fail', 'sentence', 'classroom'] 
        for i in range(len(self.videos)):
            v = self.videos[i]
            l = self.labels[i]
            for f in v:
                if np.count_nonzero(f) > 300:
                    self.images.append(f)
                    self.imglabels.append(lab_dic.index(l))
                frames_count+=1
        self.images = np.array(self.images)
        indices = list(range(len(self.images)))
        np.random.shuffle(indices) 
        # ic(indices[0])
        # ic(self.imglabels[0])
        shuffle_images = np.array(self.shuffle(self.images,indices))
        shuffle_labels = np.array(self.shuffle(self.imglabels,indices))


        # shuffle_images =  shuffle_images[:,np.newaxis,:,:] 

        self.X = shuffle_images[:8*len(shuffle_images)//10] 
        self.Y = shuffle_labels[:8*len(shuffle_images)//10] 
        self.VX = shuffle_images[8*len(shuffle_images)//10:] 
        self.VY = shuffle_labels[8*len(shuffle_images)//10:] 
        self.X = torch.tensor(self.X,dtype=torch.float32)
        self.VX = torch.tensor(self.VX,dtype=torch.float32)
        self.Y =  torch.tensor(self.Y,dtype=torch.long) 
        self.VY =  torch.tensor(self.VY,dtype=torch.long) 
        print("Preprocess Finished")

        
        # ic(self.X.shape,self.VX.shape) 

    def shuffle(self,x,indices):
        r = []
        for i in range(len(indices)):
            r.append(x[indices[i]])
        return r

    def __iter__(self):
        return iter(self.images[self.start:self.end])

    def __len__(self):
        return len(self.images)


def getBack(var_grad_fn):
    # Recursively walk the autograd graph and count the visited nodes.
    count = 0
    for n in var_grad_fn.next_functions:
        count += 1
        if n[0]:
            try:
                # grad_fn nodes have no .shape, so this raises AttributeError
                # and the walk recurses into the next node instead.
                count += 1
                print(n[0].shape)
            except AttributeError:
                count += getBack(n[0])
    return count


def train(model,epochs):
    it = iterdata(0,24971)

    optim = torch.optim.Adam(model.parameters(),lr=0.001)
    loss_fn = torch.nn.CrossEntropyLoss()
    h_loss = []
    h_acc = []
    vh_loss = []
    vh_acc = []
    
    for e in range(epochs):
        avg_acc = 0.0
        avg_loss = 0.0
        c = 0
        model.train()
        bs = 25
        for i in tqdm.tqdm(range(0, 100, bs)):  # note: only the first 100 training samples are used per epoch
            model.train()
            optim.zero_grad()
            X = it.X[i:i+bs].clone().detach()
            Y = it.Y[i:i+bs].clone().detach()
            print(X.shape,Y.shape)
            inp_time = time.time()
            y = model(X)
            print("Input time",time.time() - inp_time)
            # ic(y.shape,Y.shape)
            loss = loss_fn(y,Y)
            avg_loss += loss.item()
            back_time = time.time()
            loss.backward()
            ic("Entering print graph")
            time.sleep(1)

            grad_count = getBack(loss.grad_fn)
            ic("Exiting print graph",grad_count)
            time.sleep(1)
            ic(time.time()-back_time)
            step_time = time.time()
            optim.step()
            # print(dir(loss))
            # loss.zero_()
            ic(time.time()-step_time)
            accuracy = (torch.argmax(y,-1)==Y).sum().float()/X.shape[0]
            print(accuracy.grad)  # accuracy is not part of the autograd graph, so this prints None
            avg_acc += accuracy
            c+=1

        avg_loss = avg_loss/c 
        avg_acc = avg_acc/c
        ic("Epoch ",e,avg_acc,avg_loss)
        h_acc.append(avg_acc) 
        h_loss.append(avg_loss)

        if e%4 == 0:
            avg_acc = 0.0
            avg_loss = 0.0
            c = 0
            model.eval()
            bs = 25 
            for i in range(0,100,bs):
                with torch.no_grad():
                    VX = it.VX[i:i+bs]
                    VY = it.VY[i:i+bs]
                    y = model(VX)
                    # ic(y.shape,VY.shape)
                    loss = loss_fn(y,VY)
                    avg_loss += loss.item()
                    accuracy = (torch.argmax(y,-1)==VY).sum().float()/VX.shape[0]
                    avg_acc += accuracy
                    c+=1
            
            vavg_loss = avg_loss/c 
            vavg_acc = avg_acc/c
            
            print("Epoch :",e,"Validation Accuracy :",vavg_acc,"Validation Loss :",vavg_loss)
            ic(vavg_acc,vavg_loss)
            vh_acc.append(vavg_acc)
            vh_loss.append(vavg_loss)
            ic(vavg_acc,max(vh_acc))
            if vavg_acc >= max(vh_acc):
                ic("Saving model",vavg_acc)
                torch.save(model,"../models/divaten_model.pth") 
                
            plt.clf()
            plt.plot(range(len(h_loss)), h_loss)
            plt.plot(range(0, len(vh_loss) * 4, 4), vh_loss)
            plt.xlabel("epochs")
            plt.ylabel("loss")
            # plt.show()
            plt.savefig("../plots/loss_1"+".png")
            plt.clf()
            plt.plot(range(len(h_acc)), h_acc)
            plt.plot(range(0, len(vh_acc) * 4, 4), vh_acc)
            plt.xlabel("epochs")
            plt.ylabel("accuracy")
            #plt.show()
            plt.savefig("../plots/acc_1" + ".png")
            plt.clf()
            
    return model

                
if __name__ == '__main__':
    attensat = AttenMod()
    convmodel = ConvModel()

    trained_model = train(attensat, 30)

    u = torch.tensor(list(range(128*128*1)),dtype=torch.float32)
    u = u.reshape((1,128,128))

    x = time.time()
    result = attensat(u)
    print(time.time()-x)

    
    x = time.time()
    u = u.reshape((-1,1,128,128))
    result = convmodel(u)
    print(time.time()-x)
    print(result.shape)
    

I tried every debugging method I know, but I can't find the reason. I train the model on a MacBook Air, which has no GPU and only 8 GB of RAM. My question is: why do the initial batches have a small backward time, while in later epochs it grows so rapidly?
This is the output I got.

Importation Finished
Loading Pickle Finished
Preprocess Finished
ic| len(self.images): 24971
ic| len(self.imglabels): 24971
  0%|                                                                                                                                  | 0/4 [00:00<?, ?it/s]torch.Size([25, 128, 128]) torch.Size([25])
Input time 0.6769578456878662
ic| "Entering print graph": 'Entering print graph'
ic| "Exiting print graph": 'Exiting print graph', grad_count: 30974
ic| time.time()-back_time: 3.169736862182617
ic| time.time()-step_time: 0.7119827270507812
None
 25%|██████████████████████████████▌                                                                                           | 1/4 [00:04<00:13,  4.63s/it]torch.Size([25, 128, 128]) torch.Size([25])
Input time 0.6140007972717285
ic| "Entering print graph": 'Entering print graph'
ic| "Exiting print graph": 'Exiting print graph', grad_count: 30974
ic| time.time()-back_time: 3.0686051845550537
ic| time.time()-step_time: 0.34681081771850586
None
 50%|█████████████████████████████████████████████████████████████                                                             | 2/4 [00:08<00:08,  4.30s/it]torch.Size([25, 128, 128]) torch.Size([25])
Input time 0.6014959812164307
ic| "Entering print graph": 'Entering print graph'
ic| "Exiting print graph": 'Exiting print graph', grad_count: 30974
ic| time.time()-back_time: 3.1198439598083496
ic| time.time()-step_time: 0.33158206939697266
None
 75%|███████████████████████████████████████████████████████████████████████████████████████████▌                              | 3/4 [00:12<00:04,  4.20s/it]torch.Size([25, 128, 128]) torch.Size([25])
Input time 0.5630521774291992
ic| "Entering print graph": 'Entering print graph'
ic| "Exiting print graph": 'Exiting print graph', grad_count: 30974
ic| time.time()-back_time: 3.139286994934082
ic| time.time()-step_time: 0.33261704444885254
None
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:16<00:00,  4.21s/it]
ic| "Epoch ": 'Epoch '
    e: 0
    avg_acc: tensor(0.0600)
    avg_loss: 3.2162408232688904
Epoch : 0 Validation Accuracy : tensor(0.0400) Validation Loss : 3.245347559452057
ic| vavg_acc: tensor(0.0400), vavg_loss: 3.245347559452057
ic| vavg_acc: tensor(0.0400), max(vh_acc): tensor(0.0400)
ic| "Saving model": 'Saving model', vavg_acc: tensor(0.0400)
2021-03-05 09:46:49.726 Python[80864:2015874] ApplePersistenceIgnoreState: Existing state will not be touched. New state will be written to (null)
  0%|                                                                                                                                  | 0/4 [00:00<?, ?it/s]torch.Size([25, 128, 128]) torch.Size([25])
Input time 0.6340761184692383
ic| "Entering print graph": 'Entering print graph'
ic| "Exiting print graph": 'Exiting print graph', grad_count: 30974
ic| time.time()-back_time: 3.136206865310669
ic| time.time()-step_time: 0.2984278202056885
None
 25%|██████████████████████████████▌                                                                                           | 1/4 [00:04<00:12,  4.10s/it]torch.Size([25, 128, 128]) torch.Size([25])
Input time 0.5628039836883545
ic| "Entering print graph": 'Entering print graph'
ic| "Exiting print graph": 'Exiting print graph', grad_count: 30974
ic| time.time()-back_time: 3.0723559856414795
ic| time.time()-step_time: 0.45653605461120605
None
 50%|█████████████████████████████████████████████████████████████                                                             | 2/4 [00:08<00:08,  4.11s/it]torch.Size([25, 128, 128]) torch.Size([25])
Input time 0.5509109497070312
ic| "Entering print graph": 'Entering print graph'
ic| "Exiting print graph": 'Exiting print graph', grad_count: 30974
ic| time.time()-back_time: 3.0851430892944336
ic| time.time()-step_time: 0.35164594650268555
None
 75%|███████████████████████████████████████████████████████████████████████████████████████████▌                              | 3/4 [00:12<00:04,  4.07s/it]torch.Size([25, 128, 128]) torch.Size([25])
Input time 0.5993270874023438
ic| "Entering print graph": 'Entering print graph'
ic| "Exiting print graph": 'Exiting print graph', grad_count: 30974
ic| time.time()-back_time: 3.1054790019989014
ic| time.time()-step_time: 0.4000682830810547
None
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:16<00:00,  4.09s/it]
ic| "Epoch ": 'Epoch '
    e: 1
    avg_acc: tensor(0.0700)
    avg_loss: 3.2153475284576416
  0%|                                                                                                                                  | 0/4 [00:00<?, ?it/s]torch.Size([25, 128, 128]) torch.Size([25])
Input time 0.5744240283966064
ic| "Entering print graph": 'Entering print graph'
ic| "Exiting print graph": 'Exiting print graph', grad_count: 30974
ic| time.time()-back_time: 3.078119993209839
ic| time.time()-step_time: 0.43111276626586914
None
 25%|██████████████████████████████▌                                                                                           | 1/4 [00:04<00:12,  4.11s/it]torch.Size([25, 128, 128]) torch.Size([25])
Input time 0.5769028663635254
ic| "Entering print graph": 'Entering print graph'
ic| "Exiting print graph": 'Exiting print graph', grad_count: 30974
ic| time.time()-back_time: 3.098628282546997
ic| time.time()-step_time: 0.41374993324279785
None
 50%|█████████████████████████████████████████████████████████████                                                             | 2/4 [00:08<00:08,  4.12s/it]torch.Size([25, 128, 128]) torch.Size([25])
Input time 0.5553209781646729
ic| "Entering print graph": 'Entering print graph'
ic| "Exiting print graph": 'Exiting print graph', grad_count: 30974
ic| time.time()-back_time: 3.076097011566162
ic| time.time()-step_time: 0.41425514221191406
None
 75%|███████████████████████████████████████████████████████████████████████████████████████████▌                              | 3/4 [00:12<00:04,  4.10s/it]torch.Size([25, 128, 128]) torch.Size([25])
Input time 0.5466268062591553
ic| "Entering print graph": 'Entering print graph'
ic| "Exiting print graph": 'Exiting print graph', grad_count: 30974
ic| time.time()-back_time: 3.147501230239868
ic| time.time()-step_time: 0.39598798751831055
None
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:16<00:00,  4.11s/it]
ic| "Epoch ": 'Epoch '
    e: 2
    avg_acc: tensor(0.0700)
    avg_loss: 3.2153475284576416
  0%|                                                                                                                                  | 0/4 [00:00<?, ?it/s]torch.Size([25, 128, 128]) torch.Size([25])
Input time 0.5708968639373779
ic| "Entering print graph": 'Entering print graph'
ic| "Exiting print graph": 'Exiting print graph', grad_count: 30974
ic| time.time()-back_time: 3.1175291538238525
ic| time.time()-step_time: 0.3312060832977295
None
 25%|██████████████████████████████▌                                                                                           | 1/4 [00:04<00:12,  4.05s/it]torch.Size([25, 128, 128]) torch.Size([25])
Input time 0.5483980178833008
ic| "Entering print graph": 'Entering print graph'
ic| "Exiting print graph": 'Exiting print graph', grad_count: 30974
ic| time.time()-back_time: 3.10416316986084
ic| time.time()-step_time: 0.3513929843902588
None
 50%|█████████████████████████████████████████████████████████████                                                             | 2/4 [00:08<00:08,  4.04s/it]torch.Size([25, 128, 128]) torch.Size([25])
Input time 0.5628242492675781
ic| "Entering print graph": 'Entering print graph'
ic| "Exiting print graph": 'Exiting print graph', grad_count: 30974
ic| time.time()-back_time: 3.451127052307129
ic| time.time()-step_time: 0.32834386825561523
None
 75%|███████████████████████████████████████████████████████████████████████████████████████████▌                              | 3/4 [00:12<00:04,  4.20s/it]torch.Size([25, 128, 128]) torch.Size([25])
Input time 0.5521800518035889
ic| "Entering print graph": 'Entering print graph'
ic| "Exiting print graph": 'Exiting print graph', grad_count: 30974
ic| time.time()-back_time: 3.150189161300659
ic| time.time()-step_time: 0.3240208625793457
None
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:16<00:00,  4.13s/it]
ic| "Epoch ": 'Epoch '
    e: 3
    avg_acc: tensor(0.0700)
    avg_loss: 3.2153475284576416
  0%|                                                                                                                                  | 0/4 [00:00<?, ?it/s]torch.Size([25, 128, 128]) torch.Size([25])
Input time 0.5504856109619141
ic| "Entering print graph": 'Entering print graph'
ic| "Exiting print graph": 'Exiting print graph', grad_count: 30974
ic| time.time()-back_time: 3.2490029335021973
ic| time.time()-step_time: 0.34058690071105957
None
 25%|██████████████████████████████▌                                                                                           | 1/4 [00:04<00:12,  4.17s/it]torch.Size([25, 128, 128]) torch.Size([25])
Input time 0.5655758380889893
ic| "Entering print graph": 'Entering print graph'
ic| "Exiting print graph": 'Exiting print graph', grad_count: 30974
ic| time.time()-back_time: 3.2228009700775146
ic| time.time()-step_time: 0.3308389186859131
None
 50%|█████████████████████████████████████████████████████████████                                                             | 2/4 [00:08<00:08,  4.16s/it]torch.Size([25, 128, 128]) torch.Size([25])
Input time 0.5511460304260254
ic| "Entering print graph": 'Entering print graph'
ic| "Exiting print graph": 'Exiting print graph', grad_count: 30974
ic| time.time()-back_time: 3.37975811958313
ic| time.time()-step_time: 0.33722710609436035
None
 75%|███████████████████████████████████████████████████████████████████████████████████████████▌                              | 3/4 [00:12<00:04,  4.22s/it]torch.Size([25, 128, 128]) torch.Size([25])
Input time 0.5787301063537598
ic| "Entering print graph": 'Entering print graph'
ic| "Exiting print graph": 'Exiting print graph', grad_count: 30974
ic| time.time()-back_time: 3.51228404045105
ic| time.time()-step_time: 0.4722568988800049
None
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:17<00:00,  4.30s/it]
ic| "Epoch ": 'Epoch '
    e: 4
    avg_acc: tensor(0.0700)
    avg_loss: 3.2153475284576416
Epoch : 4 Validation Accuracy : tensor(0.0400) Validation Loss : 3.245347559452057
ic| vavg_acc: tensor(0.0400), vavg_loss: 3.245347559452057
ic| vavg_acc: tensor(0.0400), max(vh_acc): tensor(0.0400)
ic| "Saving model": 'Saving model', vavg_acc: tensor(0.0400)
  0%|                                                                                                                                  | 0/4 [00:00<?, ?it/s]torch.Size([25, 128, 128]) torch.Size([25])
Input time 0.5875420570373535
ic| "Entering print graph": 'Entering print graph'
ic| "Exiting print graph": 'Exiting print graph', grad_count: 30974
ic| time.time()-back_time: 3.7260007858276367
ic| time.time()-step_time: 0.4166100025177002
None
 25%|██████████████████████████████▌                                                                                           | 1/4 [00:04<00:14,  4.76s/it]torch.Size([25, 128, 128]) torch.Size([25])
Input time 0.5677318572998047
ic| "Entering print graph": 'Entering print graph'
ic| "Exiting print graph": 'Exiting print graph', grad_count: 30974
ic| time.time()-back_time: 3.857970952987671
ic| time.time()-step_time: 0.3497660160064697
None
 50%|█████████████████████████████████████████████████████████████                                                             | 2/4 [00:09<00:09,  4.79s/it]torch.Size([25, 128, 128]) torch.Size([25])
Input time 0.5872910022735596
ic| "Entering print graph": 'Entering print graph'
ic| "Exiting print graph": 'Exiting print graph', grad_count: 30974
ic| time.time()-back_time: 4.526216983795166
ic| time.time()-step_time: 0.32611727714538574
None
 75%|███████████████████████████████████████████████████████████████████████████████████████████▌                              | 3/4 [00:15<00:05,  5.10s/it]torch.Size([25, 128, 128]) torch.Size([25])
Input time 0.5695152282714844
ic| "Entering print graph": 'Entering print graph'
ic| "Exiting print graph": 'Exiting print graph', grad_count: 30974
ic| time.time()-back_time: 5.2922961711883545
ic| time.time()-step_time: 0.3315248489379883
None
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:21<00:00,  5.32s/it]
ic| "Epoch ": 'Epoch '
    e: 5
    avg_acc: tensor(0.0700)
    avg_loss: 3.2153475284576416
  0%|                                                                                                                                  | 0/4 [00:00<?, ?it/s]torch.Size([25, 128, 128]) torch.Size([25])
Input time 0.5940840244293213
ic| "Entering print graph": 'Entering print graph'
ic| "Exiting print graph": 'Exiting print graph', grad_count: 30974
ic| time.time()-back_time: 6.436051845550537
ic| time.time()-step_time: 0.32418203353881836
None
 25%|██████████████████████████████▌                                                                                           | 1/4 [00:07<00:22,  7.39s/it]torch.Size([25, 128, 128]) torch.Size([25])
Input time 0.555290937423706
ic| "Entering print graph": 'Entering print graph'
ic| "Exiting print graph": 'Exiting print graph', grad_count: 30974
ic| time.time()-back_time: 8.099836111068726
ic| time.time()-step_time: 0.3228638172149658
None
 50%|█████████████████████████████████████████████████████████████                                                             | 2/4 [00:16<00:16,  8.34s/it]torch.Size([25, 128, 128]) torch.Size([25])
Input time 0.552196741104126
ic| "Entering print graph": 'Entering print graph'
ic| "Exiting print graph": 'Exiting print graph', grad_count: 30974
ic| time.time()-back_time: 10.675221920013428
ic| time.time()-step_time: 0.3215208053588867
None
 75%|███████████████████████████████████████████████████████████████████████████████████████████▌                              | 3/4 [00:27<00:09,  9.82s/it]torch.Size([25, 128, 128]) torch.Size([25])
Input time 0.5509300231933594
 75%|███████████████████████████████████████████████████████████████████████████████████████████▌                              | 3/4 [00:42<00:14, 14.11s/it]

I use tqdm to estimate the time. As you can see, the time per iteration increases towards the end of the output.
Thank you.

Could you check whether you might be running out of memory and your system might be using swap?
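
A quick way to check is to log the memory usage of the training process every iteration. Here is a minimal sketch using psutil (an assumption on my side; it is not part of your code and needs pip install psutil):

import os
import psutil

process = psutil.Process(os.getpid())

def log_memory(tag=""):
    # Resident set size of this Python process, in MB.
    rss_mb = process.memory_info().rss / 1024 ** 2
    # System-wide swap usage, in MB; if this grows while training slows
    # down, the machine is thrashing.
    swap_mb = psutil.swap_memory().used / 1024 ** 2
    print(tag, "rss: %.1f MB, swap: %.1f MB" % (rss_mb, swap_mb))

Calling log_memory() right before and after loss.backward() would show whether the slowdown correlates with the process spilling into swap.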

Yeah man, thanks. I reduced the batch size to 10 and the backward time still increased, up to a maximum of about one minute, after which it went back down to normal. But why does this happen? Why does it only happen after some epochs?

I tried on another system with 16 GB of RAM, and at some point the time still increases a lot. Is this due to the for loop in the model? Did I write the forward pass efficiently? Thank you very much.

If the memory usage increases in each iteration, you might accidentally be storing the computation graph, e.g. by appending the output of loss to a list without detaching it.
However, I cannot spot the error in your current code and would recommend removing all places where tensors could be stored in some kind of container to further isolate this issue.
If the “pure” training loop also increases the memory, it might be a valid issue.
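
To illustrate that failure mode with a generic sketch (model, loss_fn, and batches are placeholders here, not taken from your code):

# Wrong: the stored loss tensor still references its computation graph,
# so the graph of every batch is kept alive and memory grows each step.
history = []
for X, Y in batches:
    loss = loss_fn(model(X), Y)
    history.append(loss)        # leaks the graph
    loss.backward()

# Right: store a detached Python float instead.
history = []
for X, Y in batches:
    loss = loss_fn(model(X), Y)
    history.append(loss.item())  # plain float, the graph can be freed
    loss.backward()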

Is the code below wrong?

            print("Input time",time.time() - inp_time)
            # ic(y.shape,Y.shape)
            loss = loss_fn(y,Y)
            avg_loss += loss.item()
            back_time = time.time()
            loss.backward()
Is adding loss.item() to a variable wrong?
Thank you

No, it’s the correct way to detach the tensor and only store the Python float.
However, removing all additional code and executing only the training loop could further narrow down the issue, and either point towards a currently unspotted error in the code or towards an internal bug causing the memory increase.
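
Something like this stripped-down loop, with the timing, graph-walking, and plotting code removed (a sketch reusing the model, optim, loss_fn, it, and bs already defined in your script):

model.train()
for e in range(epochs):
    for i in range(0, len(it.X), bs):
        optim.zero_grad()
        y = model(it.X[i:i + bs])
        loss = loss_fn(y, it.Y[i:i + bs])
        loss.backward()
        optim.step()
        # Only a Python float leaves the loop body; no tensors are stored.
        print(e, i, loss.item())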

Now I am running the code on a system with 16 GB of RAM. It runs smoothly there. Thanks man.

Are you still seeing increased memory usage during the training, or does it increase up to a specific level and then stop?

On the new system, the memory usage is constant throughout the training; it is not increasing anywhere. And the computation time for loss.backward() is constant as well.