class Generator(nn.Module):
    """DCGAN-style generator mapping latent vectors to 3-channel images.

    The latent vector is treated as a ``1 x 1`` image with ``z_dim`` channels.
    The first transposed convolution (kernel ``M``, stride 1, no padding)
    expands it to an ``M x M`` feature map. Three stride-2 transposed
    convolutions (kernel 4, padding 1) then each double the spatial size. The
    output therefore has shape ``(batch, 3, 8 * M, 8 * M)``, and its values lie
    in ``[-1, 1]`` because of the final ``Tanh``.

    Args:
        z_dim: Size of the latent vector.
        M: Spatial side after the first layer. ``M=4`` gives 32x32 images.
    """

    def __init__(self, z_dim, M):
        # Must be ``__init__`` (not ``init``) so that nn.Module construction
        # actually runs this body and registers the submodules.
        super().__init__()
        self.z_dim = z_dim
        self.main = nn.Sequential(
            # 1 x 1 -> M x M
            nn.ConvTranspose2d(self.z_dim, 256, M, 1, 0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # M -> 2M
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # 2M -> 4M
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # 4M -> 8M, 3 output channels
            nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        )

    def forward(self, z):
        """Generate images from latents ``z``.

        ``z`` is reshaped to ``(-1, z_dim, 1, 1)``, so any tensor whose element
        count is a multiple of ``z_dim`` is accepted, e.g. ``(batch, z_dim)``.
        """
        # reshape (rather than view) also accepts non-contiguous inputs.
        return self.main(z.reshape(-1, self.z_dim, 1, 1))
class Discriminator(nn.Module):
    """Convolutional discriminator for square 3-channel images of side ``M``.

    There are four convolutions, each with kernel 5, stride 2 and padding 2.
    Each one maps a side of length ``n`` to ``ceil(n / 2)``. An ``M x M``
    input therefore ends as a 10-channel map of side ``ceil(M / 16)``. That map
    is flattened and projected by a linear layer to 40 features per sample.

    Args:
        M: Side length of the square input images (e.g. 32 or 64).
    """

    def __init__(self, M):
        # Must be ``__init__`` (not ``init``) so nn.Module runs this body.
        super().__init__()
        self.main = nn.Sequential(
            # M
            nn.Conv2d(3, 32, 5, 2, 2, bias=False),
            nn.LeakyReLU(0.02, inplace=True),
            # ceil(M / 2)
            nn.Conv2d(32, 64, 5, 2, 2, bias=False),
            nn.LeakyReLU(0.02, inplace=True),
            nn.BatchNorm2d(64),
            nn.Dropout(p=0.5),
            # ceil(M / 4)
            nn.Conv2d(64, 128, 5, 2, 2, bias=False),
            nn.LeakyReLU(0.02, inplace=True),
            nn.BatchNorm2d(128),
            nn.Dropout(p=0.5),
            # ceil(M / 8)
            nn.Conv2d(128, 10, 5, 2, 2, bias=False),
            nn.ReLU(True),
            # ceil(M / 16)
        )
        # Spatial side after four ceil(n / 2) reductions == ceil(M / 16).
        side = -(-M // 16)
        # in_features must equal the flattened conv output:
        # 10 channels * side * side (e.g. 40 for M=32, 160 for M=64).
        # NOTE(review): 40 output features kept as before; a typical GAN
        # discriminator emits a single logit — confirm the intended size.
        self.linear = nn.Linear(10 * side * side, 40)

    def forward(self, x):
        """Score images ``x`` of shape ``(batch, 3, M, M)``; returns ``(batch, 40)``."""
        x = self.main(x)
        x = torch.flatten(x, start_dim=1)
        x = self.linear(x)
        return x
class Generator32(Generator):
    """Generator that produces 32x32 images.

    The base generator's output side is ``8 * M``, so ``M=4`` is required for
    32x32 output. ``M=2`` would yield only 16x16 images, which do not match
    the 32x32 inputs that Discriminator32 expects.
    """

    def __init__(self, z_dim):
        # ``__init__`` (not ``init``) so this constructor is actually invoked.
        super().__init__(z_dim, M=4)
class Discriminator32(Discriminator):
    """Discriminator for 32x32 input images."""

    def __init__(self):
        # ``__init__`` (not ``init``) so this constructor is actually invoked.
        super().__init__(M=32)