import torch
import torch.nn as nn
import torch.nn.init as init
class Generator(nn.Module):
    """DCGAN-style generator: maps a latent vector z to a 3-channel image.

    The latent vector is projected to an M x M x 256 feature map, then
    upsampled twice by transposed convolutions (spatial size M -> 2M -> 4M),
    so the output image is (batch, 3, 4*M, 4*M) with values in [-1, 1] (Tanh).

    Args:
        z_dim: dimensionality of the input latent vector.
        M: side length of the initial feature map (default 8 -> 32x32 output).
    """

    def __init__(self, z_dim, M=8):
        # BUG FIX: was `def init` / `super().init()` — the constructor never
        # ran and nn.Module was never initialized. Dunder names are required.
        super().__init__()
        self.M = M
        self.linear = nn.Linear(z_dim, M * M * 256)
        self.main = nn.Sequential(
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # M -> 2M
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # 2M -> 4M
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # final 3x3 conv keeps spatial size, maps to RGB
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh())
        self.initialize()

    def initialize(self):
        """Initialize conv/linear weights with N(0, 0.02) and zero biases."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                init.normal_(m.weight, std=0.02)
                init.zeros_(m.bias)

    def forward(self, z, *args, **kwargs):
        """Generate images from latent codes.

        Args:
            z: tensor of shape (batch, z_dim).
        Returns:
            Tensor of shape (batch, 3, 4*M, 4*M) in [-1, 1].
        """
        x = self.linear(z)
        # reshape flat projection into a (256, M, M) feature map per sample
        x = x.view(x.size(0), -1, self.M, self.M)
        x = self.main(x)
        return x
class Discriminator(nn.Module):
    """Convolutional discriminator producing one real/fake logit per image.

    Four stride-2 convolutions halve the spatial size each time
    (M -> M/2 -> M/4 -> M/8 -> M/16), ending in a linear layer over the
    flattened (10, M/16, M/16) features.

    Args:
        M: expected input image side length (default 32 -> 32x32 inputs).
    """

    def __init__(self, M=32):
        # BUG FIX: was `def init` / `super(...).init()` — the constructor
        # never ran and nn.Module was never initialized.
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # M -> M/2
            nn.Conv2d(3, 32, 5, 2, 2, bias=False),
            nn.LeakyReLU(0.02, inplace=True),
            # M/2 -> M/4
            nn.Conv2d(32, 64, 5, 2, 2, bias=False),
            nn.LeakyReLU(0.02, inplace=True),
            nn.BatchNorm2d(64),
            nn.Dropout(p=0.5),
            # M/4 -> M/8
            nn.Conv2d(64, 128, 5, 2, 2, bias=False),
            nn.LeakyReLU(0.02, inplace=True),
            nn.BatchNorm2d(128),
            nn.Dropout(p=0.5),
            # M/8 -> M/16
            nn.Conv2d(128, 10, 5, 2, 2, bias=False),
            nn.ReLU(True)
        )
        self.linear = nn.Linear(M // 16 * M // 16 * 10, 1)

    def forward(self, x):
        """Score images.

        Args:
            x: tensor of shape (batch, 3, M, M).
        Returns:
            Tensor of shape (batch, 1) — unnormalized logits.
        """
        x = self.main(x)
        x = torch.flatten(x, start_dim=1)
        x = self.linear(x)
        return x
class Generator32(Generator):
def init(self, z_dim):
super().init(z_dim, M=8)
class Discriminator32(Discriminator):
def init(self):
super().init(M=32)
# NOTE: stray chat text was left here uncommented, which is a Python syntax
# error; preserved as a comment: "this works. anyway thank you"