GPU memory use increases with each iteration

Hello.

I am writing code to train neural networks.
Training with a small dataset works fine, but with a large dataset the run fails with "RuntimeError: CUDA error: out of memory".

In my understanding, GPU memory use shouldn't depend on the size of the dataset, since PyTorch loads and stores only the data for the current iteration via the DataLoader indices. So why is this happening?

This is the dataloader code.

        import os  # needed for os.listdir below
        import re  # needed for re.split in natural_keys
        import torch.utils.data
        from preprocessing import *
        from torchvision import transforms
        from config import config
        import torch.nn.functional as F
        from scipy import ndimage
        import numpy as np
        from os.path import join
        from PIL import Image
        from plyfile import PlyData
        pre = Prepro()
    
    
        class PRDataset(torch.utils.data.Dataset):
            def __init__(self,tem_dir,ref_dir, visual=False, transform=None):
                self.ref_dir = ref_dir
                self.tem_dir = tem_dir
                self.transform=transform
                self.ref_names = os.listdir(ref_dir)
                self.ref_names.sort(key=self.natural_keys)
    
                self.tem_names = os.listdir(tem_dir)
                self.tem_names = [self.tem_names[0] for i in range(len(self.ref_names))]
                self.normalize = transforms.Normalize(
                        mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]
                        )
    
                self.trans_in = transforms.Compose([
                                           transforms.ToTensor(),
                                           self.normalize])
    
                self.trans_out = transforms.Compose([
                                            transforms.ToTensor()])
                self.visual=visual
    
            def atoi(self,text):
                return int(text) if text.isdigit() else text
    
            def natural_keys(self,text):
                '''
                alist.sort(key=natural_keys) sorts in human order
                http://nedbatchelder.com/blog/200712/human_sorting.html
                (See Toothy's implementation in the comments)
                '''
                return [self.atoi(c) for c in re.split(r'(\d+)', text)]
    
            def read_ply(self,file_path):
                plydata = PlyData.read(file_path).elements[0].data#[:,0:3]
   
                return np.array(plydata.tolist())[:,0:3]
    
    
            def __len__(self):
                return len(self.ref_names)
    
            def __getitem__(self, idx):
                tem_name = '{}'.format(self.tem_names[idx])
                tem_fullname = join(self.tem_dir, tem_name)
                ref_name = '{}'.format(self.ref_names[idx])
                ref_fullname = join(self.ref_dir, ref_name)
    
                tem = self.read_ply(tem_fullname)
                ref = self.read_ply(ref_fullname)

                if self.visual:
                    return [tem, ref, ref_name]
                else:
                    return [tem, ref]

And this is the training code.

from model import *
from save_ply import *
import os
import time
import torch.utils.data
import torch.nn as nn
from torch.autograd import Variable
from dataloader import *
from datasets import *
import utils
from loss_functions import Loss_Funcs
from config import config
from tensorboardX import SummaryWriter

os.environ["CUDA_VISIBLE_DEVICES"] = "3"
use_gpu = torch.cuda.is_available()
prepro = Prepro()
save = Save()
lossf = Loss_Funcs()

path_dic={
    "t_T":'../../dataset/surface/train/template/',
    "t_V": '../../dataset/surface/val/template/',
    "r_T": '../../dataset/surface/train/reference/',
    "r_V": '../../dataset/surface/val/reference/'
}


"""
Data distribution
"""
train = PRDataset(path_dic["t_T"],path_dic["r_T"], transform=True)
val = PRDataset(path_dic["t_V"],path_dic["r_V"],transform=True)
train_loader = torch.utils.data.DataLoader(train, batch_size=config.batch, shuffle=True,num_workers=4)
valid_loader = torch.utils.data.DataLoader(val, batch_size=config.batch, shuffle=True,num_workers=4)
dloaders={"train":train_loader,"valid":valid_loader}


train_num = len(train)
val_num = len(val)
print
print '<loaded data>'
print 'all data :',train_num+val_num
print 'train data :',train_num
print 'validation data :',val_num


def G(G_models,tmp,ref):

    n_pt = tmp.shape[1]

    out1_1= G_models[0](torch.transpose(tmp,1,2))
    out1_2 = G_models[1](out1_1)
    out1_2 = torch.max(out1_2, 2, keepdim=True)[0]

    out2_1 = G_models[2](torch.transpose(ref,1,2))
    out2_2 = G_models[3](out2_1)
    out2_2 = torch.max(out2_2, 2, keepdim=True)[0]

    out1_2 = out1_2.squeeze().repeat(n_pt, 1)
    out2_2 = out2_2.squeeze().repeat(n_pt, 1)
    out1_1 = torch.transpose(out1_1,1,2).squeeze()

    out_con = torch.transpose(torch.cat((out1_1, out1_2, out2_2), 1),0,1)
    out3_1 = torch.transpose(G_models[4](out_con.view(1,-1,n_pt)),1,2)

    return tmp + out3_1

def D(D_models, input):

    out_d1 = D_models[0](torch.transpose(input,1,2))
    out_d2 = D_models[1](out_d1)

    out_d2 = torch.transpose(out_d2, 0, 1)
    out_d2 = torch.max(out_d2, 2, keepdim=True)[0]

    d =  D_models[2](out_d2.view(1,-1)).squeeze()
    return out_d2, d

"""
Training
"""
def train_model(criterionGAN,G_models,D_models, dataloaders,num_epoch,G_opt,D_opt):
    print
    print '<training start>'
    writer_grad = SummaryWriter('./log/epoch_'+str(num_epoch)+'/grad')
    for epoch in range(num_epoch):

            since = time.time()
            running_train_Gloss = 0.0
            running_train_p_loss = 0.0
            running_train_Dloss = 0.0
            running_train_G_totalloss = 0.0
            running_valid_G_totalloss=0.0
            running_valid_Gloss = 0.0
            running_valid_p_loss = 0.0
            running_valid_Dloss = 0.0

            #for phase in ["train", "valid"]:
            for phase in ["train"]:
                if phase == "train":
                    for model in G_models: model.train(True)
                    for model in D_models: model.train(True)
                else:

                    for model in G_models: model.train(False)
                    for model in D_models: model.train(False)

                for i, (tmp,ref) in enumerate(dataloaders[phase]):

                    if use_gpu:
                        tmp = Variable(tmp.cuda()).float()
                        ref = Variable(ref.cuda()).float()
                        one = Variable(torch.ones(tmp.shape[0]), requires_grad=False).cuda()
                        zero = Variable(torch.zeros(tmp.shape[0]), requires_grad=False).cuda()

                    else:
                        tmp = Variable(tmp).float()
                        ref = Variable(ref).float()
                        one = Variable(torch.ones(tmp.shape[0]), requires_grad=False)
                        zero = Variable(torch.zeros(tmp.shape[0]), requires_grad=False)

                    ############################
                    # (1) Update D network
                    ###########################
                    for opt in D_opt: opt.zero_grad()
                    #for opt in G_opt: opt.zero_grad()


                    G_out = G(G_models,tmp,ref)

                    _,d_fake = D(D_models,G_out)
                    D_fake_loss = criterionGAN(d_fake,zero)

                    _,d_real = D(D_models,ref)
                    D_real_loss = criterionGAN(d_real,one)

                    D_adv_loss = (D_fake_loss + D_real_loss) * 0.5

                    if phase == "train":
                        D_adv_loss.backward(retain_graph=True)
                        #for p in D_models[1].parameters():
                            #print p.grad
                        for opt in D_opt: opt.step()

                    ############################
                    # Update G network
                    ###########################
                    #for opt in D_opt: opt.zero_grad()
                    for opt in G_opt: opt.zero_grad()

                    feature_fake, d_fake2= D(D_models,G_out)
                    feature_real, _ = D(D_models,ref)

                    G_adv_loss = criterionGAN(d_fake2, one)
                    p_loss = config.p_weight*torch.norm(feature_fake - feature_real, p=2)
                    G_loss = G_adv_loss + p_loss

                    if phase == "train":
                        G_loss.backward(retain_graph=True)

                        for opt in G_opt: opt.step()

                    if phase=="train":
                        # accumulate the loss
                        running_train_Gloss += G_adv_loss.item()
                        running_train_p_loss += p_loss.item()
                        running_train_G_totalloss += G_adv_loss.item() + p_loss.item()
                        running_train_Dloss += D_adv_loss

                    else:
                        running_valid_Gloss += G_adv_loss.item()
                        running_valid_p_loss += p_loss.item()
                        running_valid_G_totalloss += G_adv_loss.item() +  p_loss.item()
                        running_valid_Dloss += D_adv_loss


            if epoch!= 0:

                    epoch_time=time.time()-since
                    print "Time:{:f} Epoch [{}/{}] Train G_Total: {:.4f} (G_p: {:.4f} G_adv: {:.4f}) D : {:.4f} |Val G_Total: {:.4f} (G_p: {:.4f} G_adv: {:.4f}) D : {:.4f}".format(
                                epoch_time,
                                epoch,
                                num_epoch,
                                float(running_train_G_totalloss)/train_num,#8,
                                float(running_train_p_loss) / train_num,  # 8,
                                float(running_train_Gloss) / train_num,  # 8,
                                float(running_train_Dloss) / train_num,  # 8,

                                float(running_valid_G_totalloss)/val_num,#8,
                                float(running_valid_p_loss) / val_num,  # 8,
                                float(running_train_Gloss) / val_num,  # 8,
                                float(running_valid_Dloss) / val_num,  # 8,
                                )

            if epoch%10==0:
                # save trained model
                for i,model in enumerate(G_models): torch.save(model.state_dict(), "trained_model/TMP_net"+str(i+1)+"_" + str(epoch) + ".pkl")
                for i,model in enumerate(D_models): torch.save(model.state_dict(), "trained_model/TMP_net"+str(i+6)+"_" + str(epoch) + ".pkl")
    writer_grad.close()
    print
    print('<Finished Training>')
    return

if config.use_gpu:
    TMP_net1=TMP_net1().cuda()
    TMP_net2=TMP_net2().cuda()
    TMP_net3=TMP_net3().cuda()
    TMP_net4=TMP_net4().cuda()
    TMP_net5=TMP_net5().cuda()
    TMP_net6=TMP_net6().cuda()
    TMP_net7=TMP_net7().cuda()
    TMP_net8=TMP_net8().cuda()
else:
    TMP_net1=TMP_net1()
    TMP_net2=TMP_net2()
    TMP_net3=TMP_net3()
    TMP_net4=TMP_net4()
    TMP_net5=TMP_net5()
    TMP_net6=TMP_net6()
    TMP_net7=TMP_net7()
    TMP_net8=TMP_net8()

lr = 0.0001
TMP_net1_optimizer = torch.optim.Adam(TMP_net1.parameters(), lr=lr)
TMP_net2_optimizer = torch.optim.Adam(TMP_net2.parameters(), lr=lr)
TMP_net3_optimizer = torch.optim.Adam(TMP_net3.parameters(), lr=lr)
TMP_net4_optimizer = torch.optim.Adam(TMP_net4.parameters(), lr=lr)
TMP_net5_optimizer = torch.optim.Adam(TMP_net5.parameters(), lr=lr)
TMP_net6_optimizer = torch.optim.Adam(TMP_net6.parameters(), lr=lr)
TMP_net7_optimizer = torch.optim.Adam(TMP_net7.parameters(), lr=lr)
TMP_net8_optimizer = torch.optim.Adam(TMP_net8.parameters(), lr=lr)


G_models = [TMP_net1,TMP_net2,TMP_net3,TMP_net4,TMP_net5]
D_models = [TMP_net6,TMP_net7,TMP_net8]
G_opt = [TMP_net1_optimizer,TMP_net2_optimizer,TMP_net3_optimizer,TMP_net4_optimizer,TMP_net5_optimizer]
D_opt = [TMP_net6_optimizer,TMP_net7_optimizer,TMP_net8_optimizer]

utils.count_model_params(G_models)
utils.count_model_params(D_models)
criterionGAN = nn.BCELoss()
criterionMSE = nn.MSELoss()
start_time=time.time()
model=train_model(criterionGAN,G_models,D_models,dloaders,num_epoch=config.num_epoch,G_opt=G_opt,D_opt=D_opt)




Hi,

Aren’t you missing a .item() when you accumulate D_adv_loss? I think running_train_Dloss is actually holding on to the autograd graph of every previous iteration here, so GPU memory grows with each step.
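
For example, keeping only the Python number in the running sums should fix it (a minimal sketch reusing the variable names from your training loop):

    # .item() returns a plain Python float, so the running sum no longer
    # keeps the autograd graph of each iteration alive on the GPU
    if phase == "train":
        running_train_Dloss += D_adv_loss.item()
    else:
        running_valid_Dloss += D_adv_loss.item()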


Thank you so much! That was exactly the cause. 🙂

I was also having the same problem, but for me it was:

with torch.autograd.set_detect_anomaly(True):
    train(hyperparameters)

I had enabled that to check whether my gradients were exploding, and then forgot to remove the context manager.
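
In hindsight, gating it behind a flag would have made it harder to leave on by accident. A minimal sketch, assuming a hypothetical debug setting (anomaly detection adds bookkeeping to every autograd operation, so it is only meant to run while debugging):

    # hypothetical flag: keep anomaly detection off unless actively debugging
    debug = False

    with torch.autograd.set_detect_anomaly(debug):
        train(hyperparameters)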

It would be helpful if someone could explain why this was happening. Thanks!