Hello community,
I am trying to implement a model-based GAN: the generator predicts simulation parameters, and I pass its output to a custom Monte Carlo simulation function to produce simulated images. I feed these simulated images, together with real images, to the discriminator to compute the loss.
I then try to update both the generator and the discriminator based on the losses calculated in the stage above.
The problem is that the generator's weights are not updating.

Is this approach feasible in PyTorch, given that I have to break the computation graph on purpose?
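For context, here is a minimal demonstration (not the model above, just a two-op toy) of why the weights stop updating: once a tensor is detached, or round-tripped through NumPy, everything upstream of the break receives no gradient.

```python
import torch

w = torch.ones(3, requires_grad=True)

# .detach() severs the graph; requires_grad_ only restarts tracking
# from this point onward, it does not reconnect to w.
y = (w * 2.0).detach().requires_grad_(True)
loss = (y * 3.0).sum()
loss.backward()

print(w.grad)  # None: nothing upstream of the break gets a gradient
print(y.grad)  # tensor([3., 3., 3.]): gradients exist only downstream
```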

Is there a better way to address this issue?

Is it possible to define my simulation script as a custom torch layer?
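One direction worth sketching: a custom `torch.autograd.Function` can wrap a non-differentiable simulator, provided you supply a surrogate gradient in `backward`. Everything below is hypothetical: `black_box_sim` is a stand-in for the Monte Carlo simulation, and the analytic surrogate gradient is an assumption (for a real simulator you would substitute e.g. a finite-difference or learned estimate).

```python
import torch

def black_box_sim(params_np):
    # Placeholder non-differentiable simulation living in NumPy-land.
    return params_np * params_np

class SimLayer(torch.autograd.Function):
    @staticmethod
    def forward(ctx, params):
        ctx.save_for_backward(params)
        out = black_box_sim(params.detach().cpu().numpy())
        return torch.as_tensor(out, dtype=params.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        (params,) = ctx.saved_tensors
        # Surrogate gradient: here the analytic derivative of the
        # placeholder sim (2 * params). Replace with a finite-difference
        # or learned estimate for a real Monte Carlo simulator.
        return grad_output * 2.0 * params

params = torch.tensor([1.0, 2.0], requires_grad=True)
out = SimLayer.apply(params)
out.sum().backward()
print(params.grad)  # tensor([2., 4.]): gradients now reach the input
```

The key point is that autograd only needs `backward` to return *some* gradient with the right shape; how good the GAN training is then depends entirely on how faithful that surrogate is to the true simulator sensitivity.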
I am attaching a simple example flow for reference.
import math
import numpy as np
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self, in_features):
        super().__init__()
        self.disc = nn.Sequential(
            nn.Linear(in_features, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.disc(x)
class Generator(nn.Module):
    def __init__(self, n_params, scalar, test_transform=None):
        super().__init__()
        self.scalar = scalar
        self.test_transform = test_transform
        self.gen = nn.Sequential(
            nn.Linear(n_params, 2 * n_params),
            nn.ReLU(),
            nn.Linear(2 * n_params, 2 * n_params),
            nn.ReLU(),
            nn.Linear(2 * n_params, n_params),
            nn.Tanh(),
        )

    def forward(self, x):
        p1 = self.gen(x)
        # .detach().numpy() leaves the autograd graph here
        params = self.scalar.inverse_transform(p1.clone().detach().numpy())
        params = p1 + torch.tensor([100, 1.0e12, 100, 1e3, 1 * math.pi / 180])
        # Generate simulation (np.array / torch.from_numpy below also break the graph)
        im = simulate_image(params[0])
        fake = self.test_transform(image=np.swapaxes(np.array(im), 1, 1))
        fake = torch.from_numpy(fake["image"])
        fake = fake.contiguous()
        return fake
I also tried passing the p1 tensor directly to my simulation function, without converting it to NumPy first, but that didn't work either.
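If the un-scaling and the simulation can be expressed in torch ops, the graph never needs to be cut at all. A minimal sketch of that idea, where `scale` and `offset` are hypothetical stand-ins for the fitted scaler's parameters and a squared sum stands in for a torch-only simulator:

```python
import torch
import torch.nn as nn

gen = nn.Sequential(nn.Linear(5, 5), nn.Tanh())

# Stand-ins for the scaler's inverse_transform, rewritten as tensor math
# so gradients can flow through the un-scaling step.
scale = torch.tensor([100.0, 1.0e12, 100.0, 1e3, 3.14159 / 180])
offset = torch.zeros(5)

x = torch.randn(1, 5)
p1 = gen(x)
params = p1 * scale + offset   # pure torch: graph stays connected
fake = (params ** 2).sum()     # stand-in for a torch-only simulator
fake.backward()

# Every generator weight now receives a gradient.
assert all(p.grad is not None for p in gen.parameters())
```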