I’ve run into a strange memory leak while trying to implement “Improved Training of Wasserstein GANs”. I get an OOM in the middle of the second epoch, both on CPU and GPU. Memory usage seems to increase after each batch, and profiling the CPU version points to the for loop over the dataloader. Here is a roughly minimal example:
from __future__ import print_function
import argparse
import os
import random
import torch
import torch.backends.cudnn as cudnn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch import autograd
from torch import nn
from torch.autograd import Variable
# Command-line configuration for the whole script; the parsed `opt`
# namespace is read at module level (e.g. by main()).
parser = argparse.ArgumentParser()
parser.add_argument('--dataroot', required=True, help='path to dataset')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=4)
parser.add_argument('--batchSize', type=int, default=128, help='input batch size')
parser.add_argument('--imageSize', type=int, default=32, help='the height / width of the input image to network')
parser.add_argument('--nz', type=int, default=128, help='size of the latent z vector')
parser.add_argument('--num_gen_filters', type=int, default=32, help='# of gen filters in first conv layer') # 64
parser.add_argument('--num_disc_filters', type=int, default=32, help='# of discrim filters in first conv layer')
parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')
# BUG FIX: the help text claimed "default=0.0002" while the actual default is 1e-4
# (the WGAN-GP paper's learning rate); the help string now matches the default.
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate, default=1e-4')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--cuda', action='store_true', default=False, help='enables cuda')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--netG', default='', help="path to netG (to continue training)")
parser.add_argument('--netD', default='', help="path to netD (to continue training)")
parser.add_argument('--outf', default='./result', help='folder to output images and model checkpoints')
parser.add_argument('--manualSeed', type=int, default=1, help='manual seed')
opt = parser.parse_args()
# NOTE(review): benchmark is also enabled inside main() when --cuda is set;
# this module-level assignment is a harmless duplicate.
torch.backends.cudnn.benchmark = True
class Generator(nn.Module):
    """DCGAN-style generator: latent codes (N, nz) -> image batch (N, 3, 32, 32) in [-1, 1]."""

    def __init__(self, num_gen_filters, nz=128, num_channels=3):
        super(Generator, self).__init__()
        self.num_gen_filters = num_gen_filters
        # Project the latent vector to a 4x4 grid with 4*num_gen_filters channels.
        # BUG FIX: the Linear output here is 2-D (N, features), so BatchNorm1d is
        # the correct layer; BatchNorm2d requires 4-D input and raises a
        # dimensionality error on modern PyTorch (very old versions silently
        # accepted 2-D input).  Parameter shapes and state_dict keys are unchanged.
        self.preprocess = nn.Sequential(
            nn.Linear(nz, 4 * 4 * 4 * num_gen_filters),
            nn.BatchNorm1d(4 * 4 * 4 * num_gen_filters),
            nn.ReLU(True),
        )
        # Each transposed-conv block doubles spatial resolution: 4 -> 8 -> 16.
        self.block1 = nn.Sequential(
            nn.ConvTranspose2d(4 * num_gen_filters, 2 * num_gen_filters, 2, stride=2),
            nn.BatchNorm2d(2 * num_gen_filters),
            nn.ReLU(True),
        )
        self.block2 = nn.Sequential(
            nn.ConvTranspose2d(2 * num_gen_filters, num_gen_filters, 2, stride=2),
            nn.BatchNorm2d(num_gen_filters),
            nn.ReLU(True),
        )
        # Final upsample 16 -> 32 and projection to `num_channels` channels.
        self.deconv_out = nn.ConvTranspose2d(num_gen_filters, num_channels, 2, stride=2)
        self.tanh = nn.Tanh()

    def forward(self, input):
        """Map latent codes `input` of shape (N, nz) to generated images."""
        output = self.preprocess(input)
        output = output.view(-1, 4 * self.num_gen_filters, 4, 4)
        output = self.block1(output)
        output = self.block2(output)
        output = self.deconv_out(output)
        output = self.tanh(output)
        # Hard-coded for 3x32x32 (CIFAR-10 sized) output images.
        return output.view(-1, 3, 32, 32)
class Discriminator(nn.Module):
    """DCGAN-style critic: image batch (N, C, 32, 32) -> one unbounded score per image (N, 1)."""

    def __init__(self, num_disc_filters, num_channels=3):
        super(Discriminator, self).__init__()
        self.num_disc_filters = num_disc_filters
        # Three stride-2 convolutions halve the resolution each time
        # (32 -> 16 -> 8 -> 4) while doubling the channel count.
        conv_stack = [
            nn.Conv2d(num_channels, num_disc_filters, 3, 2, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(num_disc_filters, 2 * num_disc_filters, 3, 2, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(2 * num_disc_filters, 4 * num_disc_filters, 3, 2, padding=1),
            nn.LeakyReLU(),
        ]
        self.main = nn.Sequential(*conv_stack)
        # Final 4x4 feature map (4*num_disc_filters channels), flattened.
        self.linear = nn.Linear(4 * 4 * 4 * num_disc_filters, 1)

    def forward(self, input):
        """Return the raw (non-sigmoid) critic score for each input image."""
        features = self.main(input)
        flat = features.view(-1, 4 * 4 * 4 * self.num_disc_filters)
        return self.linear(flat)
def main():
    """Minimal WGAN-GP critic-training loop used to reproduce the memory growth.

    Reads the module-level ``opt`` namespace.  Trains only the discriminator
    (critic) with the gradient-penalty term from "Improved Training of
    Wasserstein GANs" (Gulrajani et al., 2017); the generator is frozen.
    """
    print(opt)
    nz = opt.nz
    # Create the output folder; ignore "already exists".
    try:
        os.makedirs(opt.outf)
    except OSError:
        pass
    if opt.manualSeed is None:
        opt.manualSeed = random.randint(1, 10000)
    print("Random Seed: ", opt.manualSeed)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
    if opt.cuda:
        torch.cuda.manual_seed_all(opt.manualSeed)
        cudnn.benchmark = True
    if torch.cuda.is_available() and not opt.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")
    # CIFAR-10, resized and normalized to [-1, 1] to match the Tanh output of G.
    dataset = dset.CIFAR10(root=opt.dataroot, download=True, transform=transforms.Compose(
        [transforms.Resize(opt.imageSize), transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ]))
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                             shuffle=True, num_workers=int(opt.workers))
    netG = Generator(opt.num_gen_filters, nz)
    print(netG)
    netD = Discriminator(opt.num_disc_filters)
    print(netD)
    noise = torch.FloatTensor(opt.batchSize, nz, 1, 1)
    fixed_noise = torch.FloatTensor(opt.batchSize, nz).normal_(0, 1)
    if opt.cuda:
        netD.cuda()
        netG.cuda()
        noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
    # setup optimizer (beta2=0.9 as recommended in the WGAN-GP paper)
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.9))
    for epoch in range(opt.niter):
        for _, data in enumerate(dataloader, 0):
            # Freeze the generator: this minimal loop only updates the critic,
            # so gradients w.r.t. netG's parameters are never needed.
            for p in netG.parameters():
                p.requires_grad = False
            # train with real
            netD.zero_grad()
            real_cpu, _ = data
            batch_size = real_cpu.size(0)
            if opt.cuda:
                real_cpu = real_cpu.cuda()
            noise.resize_(batch_size, nz).normal_(0, 1)
            noisev = Variable(noise, volatile=True)  # inference-only pass through G
            # Re-wrap the data so `fake` is detached from G's forward graph.
            fake = Variable(netG(noisev).data)
            # NOTE(review): the paper draws a per-sample uniform epsilon for the
            # interpolation; a fixed 0.5 is kept here to preserve the repro.
            interpolates = 0.5 * real_cpu + 0.5 * fake.data
            if opt.cuda:
                interpolates = interpolates.cuda()
            interpolates = autograd.Variable(interpolates, requires_grad=True)
            disc_interpolates = netD(interpolates)
            grad_outputs = torch.ones(disc_interpolates.size())
            if opt.cuda:
                grad_outputs = grad_outputs.cuda()
            # create_graph=True keeps the gradient computation differentiable so
            # the penalty can be backpropagated into netD (double backward).
            gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                      grad_outputs=grad_outputs,
                                      create_graph=True, only_inputs=True)[0]
            # BUG FIX: flatten each sample's gradient before taking the norm.
            # The original `gradients.norm(2, dim=1)` on a (N, C, H, W) tensor
            # reduced only the channel axis, penalising a per-pixel norm rather
            # than the per-example gradient norm specified by the paper.
            gradients = gradients.view(batch_size, -1)
            gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * 10
            gradient_penalty.backward()
            optimizerD.step()
            # NOTE(review): nothing in this loop body holds references across
            # iterations, so the reported per-batch memory growth is presumably
            # the known double-backward (create_graph=True) leak in old PyTorch
            # releases, fixed upstream -- confirm by upgrading PyTorch.
if __name__ == '__main__':
    main()
Is there a problem with my code, or is this a bug in PyTorch?