In every iteration of my loop I run the forward pass and then backward only once, so why does this error still trigger:
Traceback (most recent call last):
File "toy_layer.py", line 635, in <module>
main()
File "toy_layer.py", line 557, in main
g_grads = torch.autograd.grad(l2, g_vars)
File "/hdd1/liangqu/conda/lib/python3.6/site-packages/torch/autograd/__init__.py", line 144, in grad
inputs, allow_unused)
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
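If it helps, here is a stripped-down sketch of the pattern I believe my training loop boils down to (toy tensors and hypothetical names, not my real model); as far as I can tell, the second outer iteration is where it trips:

import torch

w = torch.ones(3, requires_grad=True)  # stands in for g_vars
v = torch.ones(3, requires_grad=True)  # stands in for d_vars
v_cur = v
for epoch in range(2):
    x_hat = w.tanh()                    # "generator" forward
    l1 = (v_cur * x_hat).sum()          # "discriminator" loss on x_hat
    (dv,) = torch.autograd.grad(l1, [v_cur], create_graph=True)
    v_cur = v_cur - 0.1 * dv            # functional update keeps the old graph reachable
    l2 = (v_cur ** 2).sum()             # loss through the updated v_cur
    (gw,) = torch.autograd.grad(l2, [w])  # on my understanding: epoch 0 passes, epoch 1 raises the same error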
And here is my full source code:
import torch, math
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms
import visdom
from torch.nn import functional as F
import gc
class Generator:
def __init__(self, z_dim, c_dim, device):
        # following PyTorch's w/b convention: w = [out_dim, in_dim],
        # b = [out_dim]
self.vars = [
# [z+c, 128]
torch.ones(128, z_dim + c_dim, requires_grad=True, device=device),
torch.zeros(128, requires_grad=True, device=device),
torch.rand(128, requires_grad=True, device=device),
torch.zeros(128, requires_grad=True, device=device),
# [128, 256]
torch.ones(256, 128, requires_grad=True, device=device),
torch.zeros(256, requires_grad=True, device=device),
torch.rand(256, requires_grad=True, device=device),
torch.zeros(256, requires_grad=True, device=device),
# [256, 512]
torch.ones(512, 256, requires_grad=True, device=device),
torch.zeros(512, requires_grad=True, device=device),
torch.rand(512, requires_grad=True, device=device),
torch.zeros(512, requires_grad=True, device=device),
# [512, 1024]
torch.ones(1024, 512, requires_grad=True, device=device),
torch.zeros(1024, requires_grad=True, device=device),
torch.rand(1024, requires_grad=True, device=device),
torch.zeros(1024, requires_grad=True, device=device),
# [1024, 28*28]
torch.ones(28 * 28, 1024, requires_grad=True, device=device),
torch.zeros(28 * 28, requires_grad=True, device=device),
torch.rand(28 * 28, requires_grad=True, device=device),
torch.zeros(28 * 28, requires_grad=True, device=device),
]
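        # (4 tensors per layer above: linear weight, linear bias,
        # batch-norm weight/gamma, batch-norm bias/beta)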
        # running mean and variance for batch norm; F.batch_norm updates
        # them in place when training=True, so no gradients are needed.
self.bns = [
torch.zeros(128, device=device),
torch.ones(128, device=device),
torch.zeros(256, device=device),
torch.ones(256, device=device),
torch.zeros(512, device=device),
torch.ones(512, device=device),
torch.zeros(1024, device=device),
torch.ones(1024, device=device),
torch.zeros(28 * 28, device=device),
torch.ones(28 * 28, device=device),
]
    def init_weight(self, vars=None):
        """
        re-initialize vars and self.bns in place
        :param vars: parameter list to initialize; defaults to self.vars
        :return:
        """
        if vars is None:
            vars = self.vars
        # each layer owns 4 tensors: linear w/b followed by batch-norm w/b
        for vars_idx in range(0, len(vars), 4):
            weight, bias = vars[vars_idx], vars[vars_idx + 1]
            stdv = 1. / math.sqrt(weight.size(1))
            weight.uniform_(-stdv, stdv)
            bias.uniform_(-stdv, stdv)
            # nn.init.xavier_uniform_(weight)
            # bias.data.fill_(0.01)
            weight, bias = vars[vars_idx + 2], vars[vars_idx + 3]
            weight.uniform_()
            bias.zero_()
        # reset running stats to zero mean and unit variance
        for i in range(len(self.bns) // 2):
            # mean
            self.bns[2 * i].zero_()
            # variance
            self.bns[2 * i + 1].fill_(1)
    def forward(self, z, c, vars):
        """
        :param z: [b, z_dim] noise
        :param c: [b, c_dim] one-hot condition
        :param vars: parameter list, 4 tensors per layer
        :return: [b, 1, 28, 28] generated images
        """
vars_idx, bns_idx = 0, 0
# [b, z_dim] + [b, c_dim] => [b, new_dim]
x = torch.cat([z, c], dim=1)
# [b, z+c] => [b, 128]
x = F.linear(x, vars[vars_idx], vars[vars_idx + 1])
x = F.batch_norm(x, self.bns[bns_idx + 0], self.bns[bns_idx + 1],
                         weight=vars[vars_idx + 2], bias=vars[vars_idx + 3],
training=True, momentum=0.1)
x = F.leaky_relu(x, 0.2)
vars_idx += 4
bns_idx += 2
# [b, 128] => [b, 256]
x = F.linear(x, vars[vars_idx], vars[vars_idx + 1])
x = F.batch_norm(x, self.bns[bns_idx + 0], self.bns[bns_idx + 1],
                         weight=vars[vars_idx + 2], bias=vars[vars_idx + 3],
training=True, momentum=0.1)
x = F.leaky_relu(x, 0.2)
vars_idx += 4
bns_idx += 2
# [b, 256] => [b, 512]
x = F.linear(x, vars[vars_idx], vars[vars_idx + 1])
x = F.batch_norm(x, self.bns[bns_idx + 0], self.bns[bns_idx + 1],
                         weight=vars[vars_idx + 2], bias=vars[vars_idx + 3],
training=True, momentum=0.1)
x = F.leaky_relu(x, 0.2)
vars_idx += 4
bns_idx += 2
# [b, 512] => [b, 1024]
x = F.linear(x, vars[vars_idx], vars[vars_idx + 1])
x = F.batch_norm(x, self.bns[bns_idx + 0], self.bns[bns_idx + 1],
                         weight=vars[vars_idx + 2], bias=vars[vars_idx + 3],
training=True, momentum=0.1)
x = F.leaky_relu(x, 0.2)
vars_idx += 4
bns_idx += 2
# [b, 1024] => [b, 28*28]
x = F.linear(x, vars[vars_idx], vars[vars_idx + 1])
x = F.batch_norm(x, self.bns[bns_idx + 0], self.bns[bns_idx + 1],
                         weight=vars[vars_idx + 2], bias=vars[vars_idx + 3],
training=True, momentum=0.1)
x = F.tanh(x)
vars_idx += 4
bns_idx += 2
# reshape
x = x.view(-1, 1, 28, 28)
return x
class Discriminator:
def __init__(self, n_class, device):
        # following PyTorch's w/b convention: w = [out_dim, in_dim],
        # b = [out_dim]
self.vars = [
# [28*28, 512]
torch.ones(512, 28 * 28, requires_grad=True, device=device),
torch.zeros(512, requires_grad=True, device=device),
# [512, 256]
torch.ones(256, 512, requires_grad=True, device=device),
torch.zeros(256, requires_grad=True, device=device),
# [256, n]
torch.ones(n_class, 256, requires_grad=True, device=device),
torch.zeros(n_class, requires_grad=True, device=device)
]
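        # (2 tensors per layer above: linear weight, linear bias)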
def init_weight(self, vars=None):
if vars is None:
vars = self.vars
        # each layer owns 2 tensors: linear w/b
        for vars_idx in range(0, len(vars), 2):
            weight, bias = vars[vars_idx], vars[vars_idx + 1]
            stdv = 1. / math.sqrt(weight.size(1))
            weight.uniform_(-stdv, stdv)
            bias.uniform_(-stdv, stdv)
            # nn.init.xavier_uniform_(weight)
            # bias.data.fill_(0.01)
    def forward(self, x, vars):
        """
        :param x: [b, 1, 28, 28]
        :param vars: parameter list, 2 tensors per layer
        :return: [b, n_class] sigmoid scores
        """
vars_idx = 0
        # flatten: [b, 1, 28, 28] => [b, 28*28]
x = x.view(x.size(0), -1)
# [b, 28*28] => [b, 512]
x = F.linear(x, vars[vars_idx], vars[vars_idx + 1])
# x = self.bn1(x)
x = F.leaky_relu(x, 0.2)
vars_idx += 2
# [b, 512] => [b, 256]
x = F.linear(x, vars[vars_idx], vars[vars_idx + 1])
# x = self.bn2(x)
x = F.leaky_relu(x, 0.2)
vars_idx += 2
# [b, 256] => [b, n_class]
x = F.linear(x, vars[vars_idx], vars[vars_idx + 1])
# x = self.bn3(x)
        # followed by MSELoss, hence the sigmoid; drop it when using CrossEntropyLoss
# x = F.leaky_relu(x, 0.2)
x = F.sigmoid(x)
vars_idx += 2
return x
def main():
from mnist_class import MNIST
lr_d = 5
lr_g = 2e-4
imagesz = 28
batchsz_d = 100
batchsz_g = 100
z_dim = 100
n_class = 10
device = torch.device('cuda')
vis = visdom.Visdom()
transform = transforms.Compose([transforms.Resize([imagesz, imagesz]),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5,), std=(0.5,))])
# use self defined MNIST
mnist = MNIST('data/mnist', class_idx=range(n_class), train=True, download=True, transform=transform)
db = DataLoader(mnist, batch_size=batchsz_g, shuffle=True)
db_iter = iter(db)
c_dist = torch.distributions.categorical.Categorical(probs=torch.tensor([1/n_class] * n_class))
g = Generator(z_dim, n_class, device)
d = Discriminator(n_class, device)
    # init is very important, otherwise gradients can turn into NaN
with torch.no_grad():
g.init_weight()
d.init_weight()
g_vars, d_vars = g.vars, d.vars
g_optim = optim.Adam(g_vars, lr=lr_g, betas=(0.5, 0.999))
# d_optim = optim.Adam(d_vars, lr=2e-4, betas=(0.5, 0.999))
# when using MSELoss, append F.sigmoid() at the end of D.
# criteon = nn.CrossEntropyLoss(size_average=True).to(device)
criteon = nn.MSELoss(size_average=True).to(device=device)
criteon2 = nn.MSELoss(size_average=True).to(device=device)
for epoch in range(1000000):
# [b, z]
z = torch.rand(batchsz_d, z_dim).to(device)
# [b] => [b, 1]
        y_hat = c_dist.sample((batchsz_d,)).unsqueeze(1)
        # [b, 1] => [b, n_class] one-hot via scatter_
y_hat_oh = torch.zeros(batchsz_d, n_class).scatter_(1, y_hat, 1)
# [b, 1] => [b]
y_hat, y_hat_oh = y_hat.squeeze(1).to(device), y_hat_oh.to(device)
# print(y_hat, y_hat_oh)
# [b, z+c] => [b, 1, 28, 28]
x_hat = g.forward(z, y_hat_oh, g_vars)
# 1. update D nets
losses_d = []
for i in range(10):
pred = d.forward(x_hat, d_vars)
l1 = criteon(pred, y_hat_oh)
            # MUST use create_graph=True to support the 2nd derivative.
d_grads = torch.autograd.grad(l1, d_vars, create_graph=True)
d_vars = list(map(lambda p: p[0] - lr_d * p[1], zip(d_vars, d_grads)))
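            # NOTE: this functional update leaves the new d_vars as non-leaf
            # tensors that stay connected to x_hat's graph (and thus to g_vars)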
losses_d += [l1.item()]
# 2. update G nets
# [b, 1, 28, 28], [b]
try:
x, y = next(db_iter)
except StopIteration as err:
db_iter = iter(db)
x, y = next(db_iter)
y_oh = torch.zeros(y.size(0), n_class).scatter_(1, y.unsqueeze(1).long(), 1)
x, y, y_oh = x.to(device), y.to(device), y_oh.to(device)
pred = d.forward(x, d_vars)
l2 = criteon2(pred, y_oh)
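        # this is the call that raises the RuntimeError in the traceback above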
g_grads = torch.autograd.grad(l2, g_vars)
# g_vars = list(map(lambda p: p[0] - lr_d * p[1], zip(g_vars, g_grads)))
with torch.no_grad():
for p, grad in zip(g_vars, g_grads):
if p.grad is not None:
p.grad.copy_(grad.detach())
else:
p.grad = grad.detach()
# g_optim.zero_grad()
# l2.backward(retain_graph=True)
g_optim.step()
# # [b, n_class] => [b]
# pred = torch.argmax(pred, dim=1)
# correct = torch.eq(pred, y).sum().float()
# acc = correct.item() / np.prod(y.size()) # NOT y.sum()
# print('>>>>memory check')
# total_tensor, total_mem = 0, 0
# for obj in gc.get_objects():
# if torch.is_tensor(obj):
# total_tensor += 1
# total_mem += np.prod(obj.size())
# print(obj.type(), obj.size())
# print('<<<<', 'tensor:', total_tensor, 'mem:', total_mem//1024//1024)
if __name__ == '__main__':
import argparse
args = argparse.ArgumentParser()
args.add_argument('-g', action='store_true', help='use gan to train')
args = args.parse_args()
main()