RuntimeError: Trying to backward through the graph a second time

In each iteration of the loop I run the forward pass and then back-propagate (via torch.autograd.grad) only once, so why is this error still triggered:

Traceback (most recent call last):
  File "toy_layer.py", line 635, in <module>
    main()
  File "toy_layer.py", line 557, in main
    g_grads = torch.autograd.grad(l2, g_vars)
  File "/hdd1/liangqu/conda/lib/python3.6/site-packages/torch/autograd/__init__.py", line 144, in grad
    inputs, allow_unused)
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

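To make it clear what I mean by "only once", this toy sketch (made-up shapes, not my actual model) is the pattern I assumed is safe: a fresh graph is built every iteration and torch.autograd.grad() is called on it exactly once:

import torch

w = torch.randn(3, 3, requires_grad=True)

for step in range(5):
	x = torch.randn(4, 3)                      # new input every iteration
	loss = (x @ w).pow(2).mean()               # fresh graph every iteration
	grad, = torch.autograd.grad(loss, [w])     # backward through that graph only once
	with torch.no_grad():
		w -= 0.01 * grad                       # plain in-place update, no graph kept
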
Here is my source code:

import  torch, math
import  numpy as np
import  torch.nn as nn
import  torch.optim as optim
from    torch.utils.data import DataLoader, TensorDataset
from    torchvision import datasets, transforms
import  visdom
from    torch.nn import functional as F
import  gc







class Generator:

	def __init__(self, z_dim, c_dim, device):

		# following the PyTorch weight/bias layout: w = [out_dim, in_dim]
		# b = [out_dim]
		self.vars = [
			# [z+c, 128]
			torch.ones(128, z_dim + c_dim, requires_grad=True, device=device),
			torch.zeros(128, requires_grad=True, device=device),
			torch.rand(128, requires_grad=True, device=device),
			torch.zeros(128, requires_grad=True, device=device),
			# [128, 256]
			torch.ones(256, 128, requires_grad=True, device=device),
			torch.zeros(256, requires_grad=True, device=device),
			torch.rand(256, requires_grad=True, device=device),
			torch.zeros(256, requires_grad=True, device=device),
			# [256, 512]
			torch.ones(512, 256, requires_grad=True, device=device),
			torch.zeros(512, requires_grad=True, device=device),
			torch.rand(512, requires_grad=True, device=device),
			torch.zeros(512, requires_grad=True, device=device),
			# [512, 1024]
			torch.ones(1024, 512, requires_grad=True, device=device),
			torch.zeros(1024, requires_grad=True, device=device),
			torch.rand(1024, requires_grad=True, device=device),
			torch.zeros(1024, requires_grad=True, device=device),
			# [1024, 28*28]
			torch.ones(28 * 28, 1024, requires_grad=True, device=device),
			torch.zeros(28 * 28, requires_grad=True, device=device),
			torch.rand(28 * 28, requires_grad=True, device=device),
			torch.zeros(28 * 28, requires_grad=True, device=device),
		]

		# moving mean and variance for normalization
		# no gradients needed.
		self.bns = [
			torch.zeros(128, device=device),
			torch.ones(128, device=device),

			torch.zeros(256, device=device),
			torch.ones(256, device=device),

			torch.zeros(512, device=device),
			torch.ones(512, device=device),

			torch.zeros(1024, device=device),
			torch.ones(1024, device=device),

			torch.zeros(28 * 28, device=device),
			torch.ones(28 * 28, device=device),

		]


	def init_weight(self, vars=None):
		"""
		init vars and self.bns
		:param vars:
		:return:
		"""

		if vars is None:
			vars = self.vars

		# each layer owns 4 tensors: linear weight, linear bias, bn weight, bn bias
		for vars_idx in range(0, len(vars), 4):
			weight, bias = vars[vars_idx], vars[vars_idx + 1]
			stdv = 1. / math.sqrt(weight.size(1))
			weight.uniform_(-stdv, stdv)
			bias.uniform_(-stdv, stdv)
			# nn.init.xavier_uniform_(weight)
			# bias.data.fill_(0.01)
			weight, bias = vars[vars_idx + 2], vars[vars_idx + 3]
			weight.uniform_()
			bias.zero_()

		# zero mean and one variance.
		for i in range(len(self.bns) // 2 ):
			# mean
			self.bns[2 * i].zero_()
			# variance
			self.bns[2 * i + 1].fill_(1)


	def forward(self, z, c, vars):
		"""

		:param z:
		:param c:
		:param vars:
		:return:
		"""

		vars_idx, bns_idx = 0, 0

		# [b, z_dim] + [b, c_dim] => [b, z_dim + c_dim]
		x = torch.cat([z, c], dim=1)

		# [b, z+c] => [b, 128]
		x = F.linear(x, vars[vars_idx], vars[vars_idx + 1])
		x = F.batch_norm(x, self.bns[bns_idx + 0], self.bns[bns_idx + 1],
		                 weight=vars[vars_idx + 2], bias= vars[vars_idx + 3],
		                 training=True, momentum=0.1)

		x = F.leaky_relu(x, 0.2)
		vars_idx += 4
		bns_idx += 2

		# [b, 128] => [b, 256]
		x = F.linear(x, vars[vars_idx], vars[vars_idx + 1])
		x = F.batch_norm(x, self.bns[bns_idx + 0], self.bns[bns_idx + 1],
		                 weight=vars[vars_idx + 2], bias= vars[vars_idx + 3],
		                 training=True, momentum=0.1)

		x = F.leaky_relu(x, 0.2)
		vars_idx += 4
		bns_idx += 2

		# [b, 256] => [b, 512]
		x = F.linear(x, vars[vars_idx], vars[vars_idx + 1])
		x = F.batch_norm(x, self.bns[bns_idx + 0], self.bns[bns_idx + 1],
		                 weight=vars[vars_idx + 2], bias= vars[vars_idx + 3],
		                 training=True, momentum=0.1)

		x = F.leaky_relu(x, 0.2)
		vars_idx += 4
		bns_idx += 2

		# [b, 512] => [b, 1024]
		x = F.linear(x, vars[vars_idx], vars[vars_idx + 1])
		x = F.batch_norm(x, self.bns[bns_idx + 0], self.bns[bns_idx + 1],
		                 weight=vars[vars_idx + 2], bias= vars[vars_idx + 3],
		                 training=True, momentum=0.1)

		x = F.leaky_relu(x, 0.2)
		vars_idx += 4
		bns_idx += 2

		# [b, 1024] => [b, 28*28]
		x = F.linear(x, vars[vars_idx], vars[vars_idx + 1])
		x = F.batch_norm(x, self.bns[bns_idx + 0], self.bns[bns_idx + 1],
		                 weight=vars[vars_idx + 2], bias= vars[vars_idx + 3],
		                 training=True, momentum=0.1)

		x = F.tanh(x)
		vars_idx += 4
		bns_idx += 2

		# reshape
		x = x.view(-1, 1, 28, 28)

		return x


class Discriminator:

	def __init__(self, n_class, device):

		# following the PyTorch weight/bias layout: w = [out_dim, in_dim]
		# b = [out_dim]
		self.vars = [
			# [28*28, 512]
			torch.ones(512, 28 * 28, requires_grad=True, device=device),
			torch.zeros(512, requires_grad=True, device=device),
			# [512, 256]
			torch.ones(256, 512, requires_grad=True, device=device),
			torch.zeros(256, requires_grad=True, device=device),
			# [256, n]
			torch.ones(n_class, 256, requires_grad=True, device=device),
			torch.zeros(n_class, requires_grad=True, device=device)
		]



	def init_weight(self, vars=None):

		if vars is None:
			vars = self.vars

		# each layer owns 2 tensors: linear weight, linear bias
		for vars_idx in range(0, len(vars), 2):
			weight, bias = vars[vars_idx], vars[vars_idx + 1]
			stdv = 1. / math.sqrt(weight.size(1))
			weight.uniform_(-stdv, stdv)
			bias.uniform_(-stdv, stdv)
			# nn.init.xavier_uniform_(weight)
			# bias.data.fill_(0.01)




	def forward(self, x, vars):
		"""

		:param x: [b, 1, 28, 28]
		:param vars:
		:return:
		"""

		vars_idx = 0

		# flatten: [b, 1, 28, 28] => [b, 28*28]
		x = x.view(x.size(0), -1)

		# [b, 28*28] => [b, 512]
		x = F.linear(x, vars[vars_idx], vars[vars_idx + 1])
		# x = self.bn1(x)
		x = F.leaky_relu(x, 0.2)
		vars_idx += 2

		# [b, 512] => [b, 256]
		x = F.linear(x, vars[vars_idx], vars[vars_idx + 1])
		# x = self.bn2(x)
		x = F.leaky_relu(x, 0.2)
		vars_idx += 2


		# [b, 256] => [b, n_class]
		x = F.linear(x, vars[vars_idx], vars[vars_idx + 1])
		# x = self.bn3(x)
		# when CrossEntropyLoss is used instead, return raw logits here (no sigmoid)
		# x = F.leaky_relu(x, 0.2)
		x = F.sigmoid(x)
		vars_idx += 2

		return x



 

def main():

	from mnist_class import MNIST

	lr_d = 5
	lr_g = 2e-4
	imagesz = 28
	batchsz_d = 100
	batchsz_g = 100
	z_dim = 100
	n_class = 10

	device = torch.device('cuda')
	vis = visdom.Visdom()


	transform = transforms.Compose([transforms.Resize([imagesz, imagesz]),
	                                transforms.ToTensor(),
	                                transforms.Normalize(mean=(0.5,), std=(0.5,))])
	# use self defined MNIST
	mnist = MNIST('data/mnist', class_idx=range(n_class), train=True, download=True, transform=transform)

	db = DataLoader(mnist, batch_size=batchsz_g, shuffle=True)
	db_iter = iter(db)

	c_dist = torch.distributions.categorical.Categorical(probs=torch.tensor([1/n_class] * n_class))

	g = Generator(z_dim, n_class, device)
	d = Discriminator(n_class, device)
	# proper init is important, otherwise the gradients can become nan
	with torch.no_grad():
		g.init_weight()
		d.init_weight()
	g_vars, d_vars = g.vars, d.vars
	g_optim = optim.Adam(g_vars, lr=lr_g, betas=(0.5, 0.999))
	# d_optim = optim.Adam(d_vars, lr=2e-4, betas=(0.5, 0.999))

	# when using MSELoss, append F.sigmoid() at the end of D.
	# criteon = nn.CrossEntropyLoss(size_average=True).to(device)
	criteon = nn.MSELoss(size_average=True).to(device=device)
	criteon2 = nn.MSELoss(size_average=True).to(device=device)

	for epoch in range(1000000):

		# [b, z]
		z = torch.rand(batchsz_d, z_dim).to(device)
		# [b] => [b, 1]
		y_hat = c_dist.sample((batchsz_d,) ).unsqueeze(1)
		# [b, 1] => [b, n_class]
		y_hat_oh = torch.zeros(batchsz_d, n_class).scatter_(1, y_hat, 1)
		# [b, 1] => [b]
		y_hat, y_hat_oh = y_hat.squeeze(1).to(device), y_hat_oh.to(device)
		# print(y_hat, y_hat_oh)

		# [b, z+c] => [b, 1, 28, 28]
		x_hat = g.forward(z, y_hat_oh, g_vars)

		# 1. update D nets
		losses_d = []
		for i in range(10):
			pred = d.forward(x_hat, d_vars)
			l1 = criteon(pred, y_hat_oh)
			# MUST set create_graph=True to support taking 2nd derivatives later.
			d_grads = torch.autograd.grad(l1, d_vars, create_graph=True)
			d_vars = list(map(lambda p: p[0] - lr_d * p[1], zip(d_vars, d_grads)))

			losses_d += [l1.item()]

		# 2. update G nets
		# [b, 1, 28, 28], [b]
		try:
			x, y = next(db_iter)
		except StopIteration as err:
			db_iter = iter(db)
			x, y = next(db_iter)
		y_oh = torch.zeros(y.size(0), n_class).scatter_(1, y.unsqueeze(1).long(), 1)
		x, y, y_oh = x.to(device), y.to(device), y_oh.to(device)


		pred = d.forward(x, d_vars)
		l2 = criteon2(pred, y_oh)
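		# this is line 557 in the traceback above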
		g_grads = torch.autograd.grad(l2, g_vars)
		# g_vars = list(map(lambda p: p[0] - lr_d * p[1], zip(g_vars, g_grads)))
		# g_optim.zero_grad()
		# l2.backward(retain_graph=True)
		with torch.no_grad():
			# write the computed gradients into .grad and let Adam apply the update
			for p, grad in zip(g_vars, g_grads):
				if p.grad is not None:
					p.grad.copy_(grad.detach())
				else:
					p.grad = grad.detach()
			g_optim.step()


		# # [b, n_class] => [b]
		# pred = torch.argmax(pred, dim=1)
		# correct = torch.eq(pred, y).sum().float()
		# acc = correct.item() / np.prod(y.size()) # NOT y.sum()

		# print('>>>>memory check')
		# total_tensor, total_mem = 0, 0
		# for obj in gc.get_objects():
		# 	if torch.is_tensor(obj):
		# 		total_tensor += 1
		# 		total_mem += np.prod(obj.size())
		# 		print(obj.type(), obj.size())
		# print('<<<<', 'tensor:', total_tensor, 'mem:', total_mem//1024//1024)






if __name__ == '__main__':
	import argparse
	args = argparse.ArgumentParser()
	args.add_argument('-g', action='store_true', help='use gan to train')
	args = args.parse_args()

	main()
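
In case the full script is too long to read, here is a stripped-down sketch of the update structure I am using (made-up shapes and losses, no GAN specifics): the D parameters are updated functionally with create_graph=True in an inner loop, and then the gradient of a new loss is taken w.r.t. the G parameter, which is the call that fails in the real script:

import torch
from torch.nn import functional as F

lr = 0.1
d_vars = [torch.randn(8, 4, requires_grad=True), torch.zeros(8, requires_grad=True)]
g_var = torch.randn(4, requires_grad=True)

for epoch in range(3):
	# "generated" input that depends on the G parameter
	x_hat = g_var.repeat(2, 1)

	# inner D updates; after the first step d_vars are non-leaf tensors
	for i in range(10):
		pred = F.linear(x_hat, d_vars[0], d_vars[1])
		l1 = pred.pow(2).mean()
		d_grads = torch.autograd.grad(l1, d_vars, create_graph=True)
		d_vars = [p - lr * g for p, g in zip(d_vars, d_grads)]

	# G update on real data: the equivalent of the failing line in my script
	x = torch.randn(2, 4)
	pred = F.linear(x, d_vars[0], d_vars[1])
	l2 = pred.pow(2).mean()
	g_grads = torch.autograd.grad(l2, [g_var])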