Is it possible to train a GAN that has an MLP generator and a GCN discriminator? The MLP generator is built with PyTorch, while the GCN discriminator is built with PyTorch Geometric (`torch_geometric`). I have included an example below in which the generator generates a graph and the discriminator classifies it as real or fake. The example is not the actual architecture I am trying to use; I have simplified it to keep the question short.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.data as data
from torch_geometric.nn import GCNConv


class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(12, 25)

    def forward(self, z):
        return torch.sigmoid(self.fc1(z))


class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(1, 64)
        self.fc1 = nn.Linear(64, 1)

    def forward(self, g):
        x, edge_index = g.x, g.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.fc1(x)
        return torch.sigmoid(x).mean()


gen = Generator()
disc = Discriminator()
optimizer = torch.optim.Adam(gen.parameters(), lr=0.01)
loss_function = nn.BCELoss()

output = gen(torch.randn(12)).view((5, 5))

# constructing COO connectivity matrix for Data object in pytorch geometric
edge_index = [[], []]
for i in range(5):
    for j in range(5):
        if output[i][j] >= 0.5:
            edge_index[0].append(i)
            edge_index[1].append(j)

# all node_features are set to 1 just for the example
d = data.Data(x=torch.Tensor([[1] for x in range(5)]),
              edge_index=torch.LongTensor(edge_index))

# classification of generated graph as real or fake by discriminator
v = disc(d)

# ground truth (same shape as v, which BCELoss requires)
real = torch.ones_like(v)

# loss and optimization
loss = loss_function(v, real)
loss.backward()
optimizer.step()
```
The example walks through only one generated graph and one loss computation, again for brevity. Discriminator training is not included because it seems to be working fine. However, when I run this, the generator's weights are not updated. I assume this is because the gradient (`grad`) is not propagated from the generator output to the `torch_geometric.data.Data` object. Does anyone know how to backpropagate through a GAN in a situation like this? Is it even possible with the combination of PyTorch and PyTorch Geometric? Thanks!
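The suspicion about the gradient can be checked in isolation, and one possible way around it sketched, in plain PyTorch (no PyG needed to run it). Part 1 mirrors the thresholding step: a `>=` comparison yields a boolean tensor and the indices built from it are integers, so neither carries a `grad_fn` and nothing downstream can reach the generator. Part 2 is a hypothetical alternative, not the question's architecture: keep the soft 5×5 matrix as a weighted adjacency and apply a hand-written dense GCN-style propagation, H = D^{-1/2}(A + I)D^{-1/2} X W. The class name `DenseGCNDiscriminator` is made up for illustration. Within PyG, passing the soft values as `edge_weight` to `GCNConv`, or using its dense `DenseGCNConv` layer, may serve the same purpose; check the PyG docs for the exact signatures.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n = 5  # number of nodes, as in the question

# Stand-in for the question's Generator: 12-dim noise -> n*n edge probabilities
gen = nn.Linear(12, n * n)
probs = torch.sigmoid(gen(torch.randn(12))).view(n, n)

# Part 1: thresholding + integer indices (what the question's loop does).
mask = probs >= 0.5                # bool tensor, no grad_fn
edge_index = mask.nonzero().t()    # shape [2, num_edges], dtype long, no grad_fn
print(probs.requires_grad, mask.requires_grad, edge_index.requires_grad)


# Part 2 (hypothetical): dense GCN-style layer on the *soft* adjacency,
# so the discriminator's output stays differentiable w.r.t. the generator.
class DenseGCNDiscriminator(nn.Module):
    def __init__(self, in_dim=1, hidden=64):
        super().__init__()
        self.lin = nn.Linear(in_dim, hidden)  # X W
        self.out = nn.Linear(hidden, 1)

    def forward(self, x, adj):
        a_hat = adj + torch.eye(adj.size(0))   # add self-loops
        deg = a_hat.sum(dim=1)                 # row sums, >= 1 due to self-loops
        d_inv_sqrt = deg.pow(-0.5)
        # D^{-1/2} (A + I) D^{-1/2} via broadcasting
        a_norm = d_inv_sqrt.unsqueeze(1) * a_hat * d_inv_sqrt.unsqueeze(0)
        h = torch.relu(a_norm @ self.lin(x))
        return torch.sigmoid(self.out(h)).mean()


disc = DenseGCNDiscriminator()
optimizer = torch.optim.Adam(gen.parameters(), lr=0.01)
loss_function = nn.BCELoss()

x = torch.ones(n, 1)                   # all node features set to 1, as in the question
v = disc(x, probs)                     # scalar probability of "real"
loss = loss_function(v, torch.ones_like(v))

optimizer.zero_grad()
loss.backward()
before = gen.weight.detach().clone()
optimizer.step()

print(gen.weight.grad is not None)              # gradient reached the generator
print(torch.equal(before, gen.weight.detach())) # weights changed after the step
```

The trade-off is that a dense adjacency costs O(n²) memory per graph, which is fine for small generated graphs but is exactly what sparse `edge_index` representations are meant to avoid for large ones.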