Retain graph with GANs

Hello,

I’m trying to get a simple GAN working on the MNIST dataset. To create samples, the generator draws a random noise vector inside its forward pass:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable


class G(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)
        self.l1 = nn.Linear(100, 128)
        self.l2 = nn.Linear(128, 784)
        self.adam = optim.Adam(self.parameters())

    def forward(self, batch_size=32):
        # Sample noise uniformly from [-1, 1)
        tensor = (torch.rand(batch_size, 100) - 0.5) * 2.
        x = Variable(tensor.cuda())
        x = F.relu(self.l1(x))
        x = F.sigmoid(self.l2(x))
        return x

    def update(self, loss):
        self.adam.zero_grad()
        loss.backward()
        self.adam.step()
```
And the main loop is:

```python
g = G()
d = D()
for epoch in range(epochs):
    for x, y in train_set:
        g_sample = g.forward(batch_size)
        real_sample = Variable(x.cuda()).view(batch_size, -1)

        real_pred, real_logits = d(real_sample)
        fake_pred, fake_logits = d(g_sample)

        g_loss = -torch.mean(torch.log(fake_pred))
        d_loss = -torch.mean(torch.log(real_pred) + torch.log(1. - fake_pred))

        d.update(d_loss)
        g.update(g_loss)  # <- the error is raised here
    # Testing and visualization stuff
```

However, when I run the script, I get an error saying that I’m trying to backward through the generator’s graph a second time. I don’t see why.

Could anyone help?

Thanks!

You have to separate the two graphs (G and D) using `detach()`. At the moment, `d_loss.backward()` inside `d.update(d_loss)` also backpropagates through G's graph and frees its buffers, so the second backward in `g.update(g_loss)` fails with exactly that error. You can see how it's implemented here:
https://github.com/pytorch/examples/blob/e0d33a69bec3eb4096c265451dbb85975eb961ea/dcgan/main.py#L225-L251
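
As a minimal sketch of the fixed inner loop, assuming the same `G`/`D` interfaces as in the question (`d` returning a `(pred, logits)` pair):

```python
g_sample = g.forward(batch_size)
real_sample = Variable(x.cuda()).view(batch_size, -1)

# Discriminator step: detach the fake sample so the backward pass
# stops at the generator's output and never touches G's graph.
real_pred, _ = d(real_sample)
fake_pred, _ = d(g_sample.detach())
d_loss = -torch.mean(torch.log(real_pred) + torch.log(1. - fake_pred))
d.update(d_loss)

# Generator step: run D again on the *attached* sample, so this
# backward is the first (and only) one through G's graph.
fake_pred, _ = d(g_sample)
g_loss = -torch.mean(torch.log(fake_pred))
g.update(g_loss)
```

Passing `retain_graph=True` to the first `backward()` (as the thread title suggests) would silence the error, but detaching is the usual fix: the discriminator update shouldn't be backpropagating into G at all.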
