I am training a WGAN-GP on 123 x 123 black-and-white images. The model trains, but the critic loss becomes very negative: at 400 epochs I get Gen_loss = 6 and Crit_Loss = -106.
So far, I have taken the following suggestions into consideration: made my critic symmetric to my generator, drawn from a larger z-dim, and added more layers to the generator. I would appreciate any other recommendations or adjustments that may alleviate this issue.
Generator Model:
class Generator(nn.Module):
    """Transposed-convolution generator that maps a noise vector to an image.

    The noise is reshaped to ``(batch, z_dim, 1, 1)`` and pushed through five
    transposed-convolution stages. None of the stages applies padding, so with
    the kernel sizes and strides used here the spatial size grows
    1 -> 5 -> 13 -> 29 -> 61 -> 123.

    Args:
        z_dim: Length of the input noise vector.
        im_chan: Number of channels in the generated image.
        hidden_dim: Base channel width; inner stages use multiples of it.
    """

    def __init__(self, z_dim=10, im_chan=1, hidden_dim=64):
        super().__init__()
        self.z_dim = z_dim

        # Stages are created in network order so parameter initialisation and
        # state_dict keys ("gen.<stage>.<layer>") stay stable.
        stages = [
            self.make_gen_block(z_dim, hidden_dim * 8),
            self.make_gen_block(hidden_dim * 8, hidden_dim * 4, kernel_size=5, stride=2, padding=1),
            self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=5, stride=2, padding=1),
            self.make_gen_block(hidden_dim * 2, hidden_dim),
            self.make_gen_block(hidden_dim, im_chan, kernel_size=3, final_layer=True),
        ]
        self.gen = nn.Sequential(*stages)

    def make_gen_block(self, input_channels, output_channels, kernel_size=5, stride=2, padding=1, final_layer=False):
        """Build one generator stage.

        Args:
            input_channels: Channels entering the stage.
            output_channels: Channels leaving the stage.
            kernel_size: Kernel size of the transposed convolution.
            stride: Stride of the transposed convolution.
            padding: Accepted for interface compatibility but NOT forwarded to
                the convolution; every stage runs with the layer's default
                padding of 0.
            final_layer: When true, the stage ends in ``Tanh`` and has no
                normalisation; otherwise it is conv -> BatchNorm -> ReLU.

        Returns:
            An ``nn.Sequential`` holding the stage's layers.
        """
        upsample = nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride)
        if final_layer:
            # Output stage: squash pixel values into [-1, 1].
            return nn.Sequential(upsample, nn.Tanh())
        return nn.Sequential(
            upsample,
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, noise):
        """Generate images from ``noise`` (first dimension is the batch)."""
        batch_size = len(noise)
        latent = noise.view(batch_size, self.z_dim, 1, 1)
        return self.gen(latent)
Critic Model:
class Critic(nn.Module):
    """Convolutional critic for WGAN-GP that scores images.

    Four strided convolution stages reduce the image to a single-channel score
    map, which ``forward`` flattens to ``(batch, -1)``.

    Normalisation is per-sample (``GroupNorm`` with one group, i.e. layer
    normalisation over channels and spatial positions) rather than batch
    normalisation. The WGAN-GP gradient penalty is defined on each sample
    independently; batch norm makes a sample's score depend on the rest of the
    batch, which invalidates that penalty and can destabilise the critic loss.

    Args:
        im_chan: Number of channels in the input image.
        hidden_dim: Base channel width; inner stages use multiples of it.
    """

    def __init__(self, im_chan=1, hidden_dim=64):
        super().__init__()
        self.crit = nn.Sequential(
            self.make_crit_block(im_chan, hidden_dim),
            self.make_crit_block(hidden_dim, hidden_dim * 4, kernel_size=5, stride=2, padding=1),
            self.make_crit_block(hidden_dim * 4, hidden_dim * 2, kernel_size=5, stride=2, padding=1),
            self.make_crit_block(hidden_dim * 2, 1, kernel_size=3, final_layer=True),
        )

    def make_crit_block(self, input_channels, output_channels, kernel_size=4, stride=2, padding=1, final_layer=False):
        """Build one critic stage.

        Args:
            input_channels: Channels entering the stage.
            output_channels: Channels leaving the stage.
            kernel_size: Kernel size of the convolution.
            stride: Stride of the convolution.
            padding: Accepted for interface compatibility but NOT forwarded to
                the convolution (forwarding it would change every feature-map
                size); all stages run with the layer's default padding of 0.
            final_layer: When true, the stage is a bare convolution so the
                critic output is an unbounded score; otherwise it is
                conv -> per-sample norm -> LeakyReLU(0.2).

        Returns:
            An ``nn.Sequential`` holding the stage's layers.
        """
        conv = nn.Conv2d(input_channels, output_channels, kernel_size, stride)
        if final_layer:
            return nn.Sequential(conv)
        return nn.Sequential(
            conv,
            # One group == layer norm over (C, H, W) for each sample, so no
            # statistics are shared across the batch.
            nn.GroupNorm(1, output_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, image):
        """Return the critic scores for ``image``, flattened per sample."""
        crit_pred = self.crit(image)
        return crit_pred.view(len(crit_pred), -1)