Retain graph with GANs



I’m trying to get a simple GAN working on the MNIST dataset. To draw samples from the generator, I feed it a batch of random noise:


    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    from torch.autograd import Variable

    class G(nn.Module):

        def __init__(self):
            super(G, self).__init__()  # needed so nn.Module registers the layers

            self.l1 = nn.Linear(100, 128)
            self.l2 = nn.Linear(128, 784)

            self.adam = optim.Adam(self.parameters())

        def forward(self, batch_size=32):
            # uniform noise in [-1, 1] as generator input
            tensor = (torch.rand(batch_size, 100) - 0.5) * 2.
            x = Variable(tensor.cuda())
            x = F.relu(self.l1(x))
            x = F.sigmoid(self.l2(x))
            return x

        def update(self, loss):
            # (body truncated in the post; presumably the usual optimizer step)
            self.adam.zero_grad()
            loss.backward()
            self.adam.step()

And the main loop is:


    g = G()
    d = D()
    for epoch in range(epochs):
        for x, y in train_set:

            g_sample = g.forward(batch_size)
            real_sample = Variable(x.cuda()).view(batch_size, -1)

            real_pred, real_logits = d(real_sample)
            fake_pred, fake_logits = d(g_sample)

            g_loss = -torch.mean(torch.log(fake_pred))
            d_loss = -torch.mean(torch.log(real_pred) + torch.log(1. - fake_pred))

            d.update(d_loss)
            g.update(g_loss)

            # testing and visualization stuff


However, when running the script, I get an error saying that I’m trying to backprop a second time through the generator graph. I don’t see why.
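The full message is the usual retain-graph error, something along these lines (exact wording may differ by PyTorch version):

    RuntimeError: Trying to backward through the graph a second time, but the
    buffers have already been freed. Specify retain_graph=True when calling
    backward the first time.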

Could anyone help?

Thanks!

(Swathikiran Sudhakaran) #2

You have to separate the two graphs (G and D) using detach. At the moment, when you call d.update(d_loss), the backward pass also runs through G’s graph and frees its buffers, so the subsequent backward for g_loss has nothing left to traverse. That’s why you are getting this error. You can see how it’s implemented here:
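A minimal sketch of that fix, reusing the names from your training loop (fake_pred_d / fake_pred_g are just renamed locals; only the loss/update section changes, the rest of the loop stays as it is):

    # Discriminator step: detach the generated batch so backward() stops
    # at the detach point and never touches (or frees) G's graph.
    fake_pred_d, _ = d(g_sample.detach())
    d_loss = -torch.mean(torch.log(real_pred) + torch.log(1. - fake_pred_d))
    d.update(d_loss)

    # Generator step: run the still-attached sample through D again so the
    # gradients of g_loss can flow all the way back into G.
    fake_pred_g, _ = d(g_sample)
    g_loss = -torch.mean(torch.log(fake_pred_g))
    g.update(g_loss)

Passing retain_graph=True to the first backward call would also silence the error, but detach is the better tool here: the discriminator step has no reason to backprop into G, and skipping that part of the graph is cheaper.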