[resolved] Out of memory in the middle of training, always at the same batch index

I ran into an out-of-memory error:
```
  File "/home/yfwu/ctavatar/tools-pytorch/reg_gan3d.py", line 171, in train
    errG.backward()
  File "/home/yfwu/pytorch/local/lib/python2.7/site-packages/torch/autograd/variable.py", line 146, in backward
    self._execution_engine.run_backward((self,), (gradient,), retain_variables)
RuntimeError: cuda runtime error (2) : out of memory at /b/wheel/pytorch-src/torch/lib/THC/generic/THCStorage.cu:66
```

I tried several times, and it always happens at the same point (after 370 training batches).
This didn't happen before I added the line `netG.load_state_dict(torch.load('…/models/netG_epoch_6.pth'))`, and I can train successfully without the pre-trained model.
I am wondering whether this line causes something to be kept in memory. Many thanks in advance.
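On the question of whether the `load_state_dict` line keeps something around: by default `torch.load` restores each tensor on the device it was saved from, so a checkpoint saved from the GPU is materialized on the GPU before `load_state_dict` copies it into `netG`. A minimal sketch of one common precaution (not necessarily the fix referred to at the end of this thread) is to load with `map_location="cpu"` and to do it once before training, rather than at the top of every `train(epoch)` call. `build_generator` and `load_generator` are hypothetical helpers, and the toy `nn.Sequential` only stands in for the real `netG`, whose architecture is not shown.

```python
import os
import tempfile

import torch
import torch.nn as nn


def build_generator():
    # Hypothetical stand-in for netG; the real architecture is not in the post.
    return nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8))


def load_generator(path, device):
    """Build the generator once and load a checkpoint into it.

    map_location="cpu" restores the saved tensors on the CPU instead of on the
    device they were saved from; load_state_dict then copies the values into
    the model's existing parameters, which already live on `device`.
    """
    net = build_generator().to(device)
    state = torch.load(path, map_location="cpu")
    net.load_state_dict(state)
    return net


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ckpt = os.path.join(tempfile.mkdtemp(), "netG_epoch_6.pth")
    torch.save(build_generator().state_dict(), ckpt)  # fake checkpoint for the demo
    netG = load_generator(ckpt, device)  # done once, outside train(epoch)
    netG.train()
    print(next(netG.parameters()).device)
```

With this arrangement, `train(epoch)` would only call `netG.train()`; the checkpoint is read a single time when the script starts.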


Here is my code:

```python
import torch
from torch.autograd import Variable

# netD, netG, optimizerD, optimizerG, criterion_GAN, label, real_label,
# fake_label and train_dataloader are defined elsewhere (not shown).

def train(epoch):
    netD.train()
    netG.load_state_dict(torch.load('../models/netG_epoch_6.pth'))
    netG.train()
    epoch_loss_D = 0
    epoch_loss_G = 0
    for batch_idx, (ins, tgs) in enumerate(train_dataloader):
        ins = Variable(ins.cuda())
        tgs = Variable(tgs.cuda())
        ############################
        # (1) Update D network: maximize log(D(tgs)) + log(1 - D(G(ins)))
        ###########################
        # train with real
        optimizerD.zero_grad()
        output_D = netD(tgs)
        label.resize_(1).fill_(real_label)
        labelv = Variable(label)
        errD_real = criterion_GAN(output_D, labelv)
        errD_real.backward()

        # train with fake
        fake = netG(ins)
        labelv = Variable(label.fill_(fake_label))
        output_D = netD(fake.detach())  # detach: D's update must not backprop into G
        errD_fake = criterion_GAN(output_D, labelv)
        errD_fake.backward()
        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(ins)))
        ###########################
        optimizerG.zero_grad()
        labelv = Variable(label.fill_(real_label))
        output_D = netD(fake)
        errG = criterion_GAN(output_D, labelv)
        errG.backward()
        optimizerG.step()
```
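For comparison, here is a minimal sketch of the same D-then-G update written against the current PyTorch API (plain tensors instead of `Variable`). The toy networks, the optimizers and the choice of `nn.BCEWithLogitsLoss` are hypothetical stand-ins, since the post does not show how `netG`, `netD` or `criterion_GAN` are built. The code above initializes `epoch_loss_D`/`epoch_loss_G` but does not show how they are updated; returning `.item()` floats here sidesteps a common source of steadily growing memory, namely accumulating loss tensors that keep every batch's autograd graph alive.

```python
import torch
import torch.nn as nn


def gan_step(netG, netD, optG, optD, ins, tgs, criterion):
    """One D update followed by one G update, mirroring the loop in the post.

    Assumes netD returns logits of shape (N, 1) and criterion is
    nn.BCEWithLogitsLoss (an assumption; criterion_GAN is not shown).
    """
    n = tgs.size(0)
    real_lbl = torch.ones(n, 1, device=tgs.device)
    fake_lbl = torch.zeros(n, 1, device=tgs.device)

    # (1) Update D: real targets -> 1, generated samples -> 0.
    optD.zero_grad()
    errD_real = criterion(netD(tgs), real_lbl)
    fake = netG(ins)
    # detach() keeps D's backward pass out of G's graph.
    errD_fake = criterion(netD(fake.detach()), fake_lbl)
    errD = errD_real + errD_fake
    errD.backward()
    optD.step()

    # (2) Update G: push D to label the generated samples as real.
    optG.zero_grad()
    errG = criterion(netD(fake), real_lbl)
    errG.backward()
    optG.step()

    # Python floats: adding these to epoch totals does not retain any graph.
    return errD.item(), errG.item()


if __name__ == "__main__":
    torch.manual_seed(0)
    netG = nn.Sequential(nn.Linear(8, 8))  # hypothetical toy generator
    netD = nn.Sequential(nn.Linear(8, 1))  # hypothetical toy discriminator (logits)
    optG = torch.optim.Adam(netG.parameters(), lr=2e-4)
    optD = torch.optim.Adam(netD.parameters(), lr=2e-4)
    criterion = nn.BCEWithLogitsLoss()
    epoch_loss_D = epoch_loss_G = 0.0
    for _ in range(5):
        ins, tgs = torch.randn(4, 8), torch.randn(4, 8)
        lossD, lossG = gan_step(netG, netD, optG, optD, ins, tgs, criterion)
        epoch_loss_D += lossD  # floats, not tensors
        epoch_loss_G += lossG
    print(epoch_loss_D, epoch_loss_G)
```

Summing the two discriminator losses before a single `backward()` yields the same gradients as the two separate `backward()` calls in the original loop.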

My problem has been solved. See:

Thanks, PyTorch forum!