Hello, everyone!
I train the GAN network. But, i have some trouble with memory.
At the begining of training, memory usage is around 20%.
However, after 20 epochs, the memory usage is over 95%!!!
Could you check my code??
(this is my training code)
def trainGenerator(self, data, target, data_pred, numPred, latent_size):
    """Run one optimisation step of the generator and report its losses.

    The generator loss is ``criterionLS(D(fake), 1) + self.ratio * lossL2``,
    where ``lossL2`` is the square root of ``criterionL2`` between the first
    two channels of ``data_pred`` and of a sample generated from an all-zero
    latent, with the second channel scaled by 15.

    Args:
        data: Batch fed to ``self.Encoder``; ``data.shape[0]`` is used as the
            batch size.
        target: Forwarded to ``utils.targetMatrix``.
        data_pred: Ground-truth tensor; indexed as ``[:, :, :2]`` for the
            reconstruction term. It is moved to ``self.device`` here.
        numPred: Forwarded to ``utils.targetMatrix`` and
            ``utils.generateLatent``.
        latent_size: Forwarded to ``utils.generateLatent``.

    Returns:
        A tuple ``(lossL2, lossFake, mean_D_fake)``: the two losses as Python
        floats and the mean discriminator output on the fake batch as a
        detached 0-d tensor. Callers that accumulate the last value over many
        steps should convert it with ``.item()`` first.
    """
    self.Encoder.train()
    self.Generator.train()
    self.Discriminator.train()

    batch_size = data.shape[0]

    # Reset the gradients owned by the generator optimiser.
    self.optimGenerator.zero_grad()

    # Build the generator inputs: [latent | encoded | target].
    encoded = self.Encoder(data)
    target_matrix = utils.targetMatrix(numPred, target).to(self.device)
    latent = utils.generateLatent(batch_size, numPred, latent_size).to(self.device)

    # torch.cat already returns a tensor on the device of its inputs, so no
    # extra .to(self.device) is needed on the concatenated results.
    z = torch.cat([latent, encoded, target_matrix], dim=2)
    zero_latent = torch.zeros(latent.shape, device=self.device)
    z_fix = torch.cat([zero_latent, encoded, target_matrix], dim=2)

    # The generator wants the discriminator to label its samples as real (1).
    y = torch.ones(batch_size, 1, device=self.device)

    fake = self.Generator(z)
    fake_fix = self.Generator(z_fix)

    # Adversarial term: discriminator output on [fake | encoded | target].
    fakeD = torch.cat([fake, encoded, target_matrix], dim=2)
    D_fake = self.Discriminator(fakeD)

    # Reconstruction term on the first two channels; the second channel is
    # weighted 15x. The weight has shape (data_pred.shape[1], 2) and
    # broadcasts over the batch dimension.
    data_pred = data_pred.to(self.device)
    weight = torch.tensor([1.0, 15.0], device=self.device).expand(data_pred.shape[1], 2)
    weight_true = data_pred[:, :, :2] * weight
    weight_fake = fake_fix[:, :, :2] * weight

    lossL2 = torch.sqrt(self.criterionL2(weight_true, weight_fake))
    lossFake = self.criterionLS(D_fake, y)
    loss = lossFake + self.ratio * lossL2

    loss.backward()
    self.optimGenerator.step()

    # Keep only the value on the instance: a graph-attached loss stored on
    # `self` would keep this step's autograd graph referenced until the next
    # call overwrites it.
    self.lossGenerator = loss.detach()

    # No explicit `del` of the locals: they are released when the function
    # returns.
    return lossL2.item(), lossFake.item(), torch.mean(D_fake.detach())
def trainDiscriminator(self, data, target, data_pred, numPred, latent_size):
    """Run one optimisation step of the discriminator and report its losses.

    The discriminator loss is the average of ``criterionLS`` on a real batch
    (label 1) and on a generated batch (label 0). Both batches are
    concatenated with the encoder output and the target matrix along dim 2.

    Args:
        data: Batch fed to ``self.Encoder``; ``data.shape[0]`` is used as the
            batch size.
        target: Forwarded to ``utils.targetMatrix``.
        data_pred: Real samples; moved to ``self.device`` before use.
        numPred: Forwarded to ``utils.targetMatrix`` and
            ``utils.generateLatent``.
        latent_size: Forwarded to ``utils.generateLatent``.

    Returns:
        A tuple of four Python floats:
        ``(lossReal, lossFake, mean(D_real), mean(D_fake))``.
    """
    self.Encoder.train()
    self.Generator.train()
    self.Discriminator.train()

    batch_size = data.shape[0]

    # Reset the gradients owned by the discriminator optimiser.
    self.optimDiscriminator.zero_grad()

    # Labels: real samples -> 1, generated samples -> 0.
    y_real = torch.ones(batch_size, 1, device=self.device)
    y_fake = torch.zeros(batch_size, 1, device=self.device)

    target_matrix = utils.targetMatrix(numPred, target).to(self.device)
    encoded = self.Encoder(data)

    # Real branch: [data_pred | encoded | target]. torch.cat returns a tensor
    # on the device of its inputs, so no further .to(self.device) is needed.
    realD = torch.cat([data_pred.to(self.device), encoded, target_matrix], dim=2)
    D_real = self.Discriminator(realD)
    lossReal = self.criterionLS(D_real, y_real)

    # Fake branch: generate a sample from [latent | encoded | target].
    latent = utils.generateLatent(batch_size, numPred, latent_size).to(self.device)
    z = torch.cat([latent, encoded, target_matrix], dim=2)
    fake = self.Generator(z)
    # NOTE(review): `fake` is not detached, so backward() below also
    # propagates through the Generator. If the Encoder is not meant to
    # receive gradients via that path, `fake.detach()` would save memory and
    # compute - confirm which optimiser owns the Encoder before changing it.
    fakeD = torch.cat([fake, encoded, target_matrix], dim=2)
    D_fake = self.Discriminator(fakeD)
    lossFake = self.criterionLS(D_fake, y_fake)

    loss = (lossReal + lossFake) * 0.5
    loss.backward()
    self.optimDiscriminator.step()

    # Keep only the value on the instance so the autograd graph of this step
    # is not referenced from `self` between calls.
    self.lossDiscriminator = loss.detach()

    # Pull the scalars out first so every tensor of this step can be dropped
    # before the cache is emptied.
    stats = (
        lossReal.item(),
        lossFake.item(),
        torch.mean(D_real).item(),
        torch.mean(D_fake).item(),
    )

    # empty_cache() can only hand back blocks that no live tensor uses, so the
    # references are dropped explicitly before calling it.
    del y_real, y_fake, target_matrix, encoded, realD, latent, z, fake, fakeD
    del D_real, D_fake, lossReal, lossFake, loss
    torch.cuda.empty_cache()

    return stats