I am building a discriminator for a conditional GAN which consists of 2 components:
-
self.main_module
which reduces an image from 256x256 to 64x64 (this was previously the only module in a patch-GAN) -
self.head
which takes a flattened output from the above module, concats it with one-hot labels and uses linear layers to get a single value output.
The problem is that after a certain number of steps, self.head
returns NaN values while none of the inputs contain NaN values.
Around 400 classes are used in the one-hot labels. Could this be too sparse? How can this be fixed?
Full code:
class Discriminator(nn.Module):
def __init__(self, in_channels=3, n_classes=0):
super().__init__()
self.main_module = nn.Sequential(
nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=3, stride=1, padding=3//2),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=3//2),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=3//2),
nn.BatchNorm2d(128, affine=True),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=3//2),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=3//2),
nn.BatchNorm2d(256, affine=True),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=3//2),
nn.BatchNorm2d(256, affine=True),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(in_channels=256, out_channels=1, kernel_size=3, stride=1, padding=3//2),
# Out is 1x64x64
)
self.head = nn.Sequential(
nn.Linear(64*64 + n_classes, 1024),
nn.BatchNorm1d(1024, affine=True),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(1024, 256),
nn.BatchNorm1d(256, affine=True),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 1),
)
def forward(self, images, labels):
x = self.main_module(images)
x = torch.cat([x.view(x.shape[0], -1), labels], dim=1)
x = self.head(x)
return x