Hello!
I have trained a generator model defined as follows:
import torch
import torch.nn as nn
import torch.optim as optim

class Downscale(nn.Module):
    def __init__(self, in_size, out_size, normalize = True, dropout = 0.0):
        super(Downscale, self).__init__()
        model = [nn.Conv2d(
            in_size,
            out_size,
            kernel_size = 4,
            stride = 2,
            padding = 1,
            bias = False
        )]
        if normalize:
            # NB: the second positional argument of BatchNorm2d is eps
            model.append(nn.BatchNorm2d(out_size, 0.8))
        model.append(nn.LeakyReLU(0.2))
        if dropout:
            model.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)
##############################################################################
class Upscale(nn.Module):
    def __init__(self, in_size, out_size, dropout = 0.0):
        super(Upscale, self).__init__()
        model = [
            nn.ConvTranspose2d(
                in_size,
                out_size,
                kernel_size = 4,
                stride = 2,
                padding = 1,
                bias = False
            ),
            nn.BatchNorm2d(out_size, 0.8),
            nn.ReLU(inplace = True)
        ]
        if dropout:
            model.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*model)

    def forward(self, x, skip_input):
        # upsample, then concatenate the skip connection along the channel dim
        x = self.model(x)
        out = torch.cat((x, skip_input), dim = 1)
        return out
##############################################################################
class Generator(nn.Module):
    def __init__(self, features_g, num_channels):
        super(Generator, self).__init__()
        self.features_g = features_g
        self.num_channels = num_channels

    def build(self):
        # input: num_channels X 64 X 64
        self.down1 = Downscale(
            in_size = self.num_channels,
            out_size = self.features_g,
            normalize = False
        )
        # input: features_g X 32 X 32
        self.down2 = Downscale(
            in_size = self.features_g,
            out_size = self.features_g * 2
        )
        # input: (features_g * 2 + num_channels) X 16 X 16 (constraint map concatenated)
        self.down3 = Downscale(
            in_size = self.features_g * 2 + self.num_channels,
            out_size = self.features_g * 4,
            dropout = 0.5
        )
        # input: (features_g * 4) X 8 X 8
        self.down4 = Downscale(
            in_size = self.features_g * 4,
            out_size = self.features_g * 8,
            dropout = 0.5
        )
        # input: (features_g * 8) X 4 X 4
        self.down5 = Downscale(
            in_size = self.features_g * 8,
            out_size = self.features_g * 8,
            dropout = 0.5
        )
        # input: (features_g * 8) X 2 X 2
        self.down6 = Downscale(
            in_size = self.features_g * 8,
            out_size = self.features_g * 8,
            dropout = 0.5
        )
        ## state: (features_g * 8) X 1 X 1 ##
        # input: (features_g * 8) X 1 X 1
        self.up1 = Upscale(
            in_size = self.features_g * 8,
            out_size = self.features_g * 8,
            dropout = 0.5
        )
        # input: (features_g * 16) X 2 X 2 (up1 output + d5 skip)
        self.up2 = Upscale(
            in_size = self.features_g * 16,
            out_size = self.features_g * 8,
            dropout = 0.5
        )
        # input: (features_g * 16) X 4 X 4 (up2 output + d4 skip)
        self.up3 = Upscale(
            in_size = self.features_g * 16,
            out_size = self.features_g * 4,
            dropout = 0.5
        )
        # input: (features_g * 8) X 8 X 8 (up3 output + d3 skip)
        self.up4 = Upscale(
            in_size = self.features_g * 8,
            out_size = self.features_g * 2
        )
        # input: (features_g * 4 + num_channels) X 16 X 16 (up4 output + d2 skip)
        self.up5 = Upscale(
            in_size = self.features_g * 4 + self.num_channels,
            out_size = self.features_g
        )
        ## state: (features_g * 2) X 32 X 32 ##
        final = [
            nn.Upsample(scale_factor = 2),
            # input: (features_g * 2) X 64 X 64
            nn.Conv2d(
                in_channels = self.features_g * 2,
                out_channels = self.num_channels,
                kernel_size = 3,
                stride = 1,
                padding = 1
            ),
            # input: num_channels X 64 X 64
            nn.Tanh()
        ]
        self.final = nn.Sequential(*final)

    def forward(self, input, constraint_map):
        d1 = self.down1(input)
        d2 = self.down2(d1)
        # inject the constraint map at the 16 X 16 resolution
        d2 = torch.cat((d2, constraint_map), dim = 1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        u1 = self.up1(d6, d5)
        u2 = self.up2(u1, d4)
        u3 = self.up3(u2, d3)
        u4 = self.up4(u3, d2)
        u5 = self.up5(u4, d1)
        return self.final(u5)

    def define_optim(self, learning_rate, beta1):
        self.optimizer = optim.Adam(self.parameters(), lr = learning_rate, betas = (beta1, 0.999))

    @staticmethod
    def init_weights(layers):
        classname = layers.__class__.__name__
        if classname.find('Conv') != -1:
            nn.init.normal_(layers.weight.data, 0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            nn.init.normal_(layers.weight.data, 1.0, 0.02)
            nn.init.constant_(layers.bias.data, 0)
Once trained, I save the state dict with:
torch.save(gen.state_dict(), "GENERATOR/gen.pt")
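For context, the training-side setup looks roughly like this (a sketch; the optimizer hyperparameters here are just illustrative):

gen = Generator(features_g = 64, num_channels = 3)
gen.build()                                            # builds and registers all sub-modules
gen.apply(Generator.init_weights)                      # custom weight init
gen.define_optim(learning_rate = 0.0002, beta1 = 0.5)
# ... training loop ...
torch.save(gen.state_dict(), "GENERATOR/gen.pt")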
However, upon loading:

model = Generator(features_g = 64, num_channels = 3)
for param in torch.load('GENERATOR/gen.pt'):
    print(param)
print(model.state_dict())
for param in model.state_dict():
    print(param)
the loop over torch.load prints the expected parameter keys, but model.state_dict() comes back as an empty OrderedDict().
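For reference, what I ultimately want to do is load the weights back in, roughly like this (a sketch, using the same checkpoint path as above):

model = Generator(features_g = 64, num_channels = 3)
state = torch.load('GENERATOR/gen.pt')
model.load_state_dict(state)   # this is where I expect the saved weights to land
model.eval()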
Any advice?
Thanks!