I have created an autoencoder in PyTorch, with the code given below. But when I try to inspect the model's parameters, it does not show all the parameters of the encoder and decoder.
import numpy as np
import torch
import torch.nn as nn

class encoder(nn.Module):
    def __init__(self,
                 input_dim,
                 encoder_conv_filters,
                 encoder_conv_kernel_size,
                 encoder_conv_strides,
                 z_dim,
                 use_batch_norm = False,
                 use_dropout = False):
        super(encoder, self).__init__()
        self.input_dim = input_dim
        self.encoder_conv_filters = encoder_conv_filters
        self.encoder_conv_kernel_size = encoder_conv_kernel_size
        self.encoder_conv_strides = encoder_conv_strides
        self.z_dim = z_dim
        self.use_batch_norm = use_batch_norm
        self.use_dropout = use_dropout
        self.n_layers_encoder = len(self.encoder_conv_filters)
        self.tensor_shapes = []
        # conv and batch-norm layers are collected in plain Python lists
        self.conv2d = []
        self.batchNorm = []
        input_channels = self.input_dim[0]
        print(self.input_dim)
        for i in range(self.n_layers_encoder):
            '''if i > 0:
                padding = 1
            else:
                padding = (2, 1)'''
            padding = 1
            print(input_channels)
            self.conv2d.append(nn.Conv2d(in_channels = input_channels,
                                         out_channels = self.encoder_conv_filters[i],
                                         kernel_size = self.encoder_conv_kernel_size[i],
                                         stride = self.encoder_conv_strides[i],
                                         padding = padding))
            input_channels = self.encoder_conv_filters[i]
            self.batchNorm.append(nn.BatchNorm2d(self.encoder_conv_filters[i], affine = True))

    def forward(self, x):
        input_channels = x.shape[1]
        print("x shape:", x.shape)
        for i in range(self.n_layers_encoder):
            x = self.conv2d[i](x)
            x = nn.LeakyReLU()(x)
            if self.use_batch_norm:
                x = self.batchNorm[i](x)
            if self.use_dropout:
                x = nn.Dropout(p = 0.25)(x)
            print("x shape:", x.shape)
            self.tensor_shapes.append(x.size())
        self.shape_before_reshape = x.shape
        x = x.view(x.shape[0], -1)
        print("x shape after reshape:", x.shape)
        # a fresh nn.Linear is built on every forward pass
        x = nn.Linear(x.shape[1], self.z_dim)(x)
        print("x shape:", x.shape)
        #x = nn.MaxPool2d(2, stride=2)(x) # b, 16, 5, 5
        print('-' * 100)
        return x, self.shape_before_reshape, self.tensor_shapes
class decoder(nn.Module):
    def __init__(self,
                 decoder_conv_t_filters,
                 decoder_conv_t_kernel_size,
                 decoder_conv_t_strides,
                 fitness_linear,
                 z_dim,
                 shape_before_reshape,
                 tensor_shapes,
                 use_batch_norm = False,
                 use_dropout = False):
        super(decoder, self).__init__()
        self.decoder_conv_t_filters = decoder_conv_t_filters
        self.decoder_conv_t_kernel_size = decoder_conv_t_kernel_size
        self.decoder_conv_t_strides = decoder_conv_t_strides
        self.z_dim = z_dim
        self.fitness_linear = fitness_linear
        self.shape_before_reshape = shape_before_reshape
        # reverse the encoder shapes and drop the deepest one; these become
        # the target output sizes for the transposed convolutions
        self.tensor_shapes = tensor_shapes[::-1]
        self.tensor_shapes = self.tensor_shapes[1:]
        self.use_batch_norm = use_batch_norm
        self.use_dropout = use_dropout
        self.n_layers_decoder = len(self.decoder_conv_t_filters)
        # transposed-conv and batch-norm layers are collected in plain Python lists
        self.convTranspose = []
        self.batchNorm = []
        self.linear1 = nn.Linear(self.z_dim, np.prod(self.shape_before_reshape[1:]))
        input_channels = shape_before_reshape[1]
        for i in range(self.n_layers_decoder):
            self.convTranspose.append(nn.ConvTranspose2d(in_channels = input_channels,
                                                         out_channels = self.decoder_conv_t_filters[i],
                                                         kernel_size = self.decoder_conv_t_kernel_size[i],
                                                         stride = self.decoder_conv_t_strides[i],
                                                         padding = 1))
            input_channels = self.decoder_conv_t_filters[i]
            self.batchNorm.append(nn.BatchNorm2d(self.decoder_conv_t_filters[i], affine = True))

    def forward(self, x):
        print('=' * 100)
        print(self.tensor_shapes)
        print("\ndecoder: x shape:", x.shape)
        x = self.linear1(x)
        fitness = x
        print("decoder: x shape after nn.Linear:", x.shape)
        print("shape_before_reshape:", self.shape_before_reshape)
        x = x.view(-1, self.shape_before_reshape[1], self.shape_before_reshape[2], self.shape_before_reshape[3])
        print("decoder: x shape after reshape:", x.shape)
        input_channels = x.shape[1]
        for i in range(self.n_layers_decoder):
            if i == self.n_layers_decoder - 1:
                x = self.convTranspose[i](x)
            else:
                x = self.convTranspose[i](x, output_size = self.tensor_shapes[i])
            if i < self.n_layers_decoder - 1:
                x = nn.LeakyReLU()(x)
                if self.use_batch_norm:
                    x = self.batchNorm[i](x)
                if self.use_dropout:
                    x = nn.Dropout(p = 0.25)(x)
            else:
                x = nn.Sigmoid()(x)
            print("i: {}, x shape:{}".format(i, x.shape))
        # Fitness estimation: fresh nn.Linear layers are built on every forward pass
        print("fitness:{}".format(fitness.shape))
        for output_shape in self.fitness_linear:
            fitness = nn.Linear(fitness.shape[1], output_shape)(fitness)
            fitness = nn.LeakyReLU()(fitness)
            if self.use_dropout:
                fitness = nn.Dropout(p = 0.25)(fitness)
            print("fitness:{}".format(fitness.shape))
        return x, fitness
class autoencoder(nn.Module):
    def __init__(self,
                 input_dim,
                 encoder_conv_filters,
                 encoder_conv_kernel_size,
                 encoder_conv_strides,
                 decoder_conv_t_filters,
                 decoder_conv_t_kernel_size,
                 decoder_conv_t_strides,
                 fitness_linear,
                 z_dim,
                 use_batch_norm = False,
                 use_dropout = False):
        super(autoencoder, self).__init__()
        self.encoder = encoder(input_dim,
                               encoder_conv_filters,
                               encoder_conv_kernel_size,
                               encoder_conv_strides,
                               z_dim,
                               use_batch_norm,
                               use_dropout)
        # dry run through the encoder to record the intermediate shapes the decoder needs
        tmp_input = torch.randn(1, input_dim[0], input_dim[1], input_dim[2])
        _, shape_before_reshape, tensor_shapes = self.encoder(tmp_input)
        self.decoder = decoder(decoder_conv_t_filters,
                               decoder_conv_t_kernel_size,
                               decoder_conv_t_strides,
                               fitness_linear,
                               z_dim,
                               shape_before_reshape,
                               tensor_shapes,
                               use_batch_norm,
                               use_dropout)

    def forward(self, x):
        x, shape_before_reshape, tensor_shapes = self.encoder(x)
        x = self.decoder(x)
        return x
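For reference, this is roughly how I build the model. The hyperparameters below are illustrative placeholders, not my real configuration; I chose them so that the flattened encoder output is 512 and z_dim is 2, matching the printout that follows:

# Hypothetical configuration, for illustration only: three stride-2 convs
# take a 1x32x32 input down to 32x4x4, so the flattened size is 512.
model = autoencoder(input_dim = (1, 32, 32),
                    encoder_conv_filters = [32, 64, 32],
                    encoder_conv_kernel_size = [3, 3, 3],
                    encoder_conv_strides = [2, 2, 2],
                    decoder_conv_t_filters = [64, 32, 1],
                    decoder_conv_t_kernel_size = [3, 3, 3],
                    decoder_conv_t_strides = [2, 2, 2],
                    fitness_linear = [128, 1],
                    z_dim = 2)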
>>> print(model)
autoencoder(
  (encoder): encoder()
  (decoder): decoder(
    (linear1): Linear(in_features=2, out_features=512, bias=True)
  )
)
>>> print(model.state_dict().keys())
odict_keys(['decoder.linear1.weight', 'decoder.linear1.bias'])
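I suspect the missing parameters are the ones held in plain Python lists (self.conv2d, self.batchNorm, self.convTranspose) and the nn.Linear layers created inside forward. A minimal sketch of my own (not taken from the code above) seems to reproduce the behaviour: a submodule stored in a plain list is invisible to state_dict, while the same submodule inside an nn.ModuleList is registered:

import torch.nn as nn

class PlainList(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [nn.Conv2d(1, 8, 3)]  # plain Python list: not registered

class WithModuleList(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Conv2d(1, 8, 3)])  # registered submodule

print(PlainList().state_dict().keys())       # odict_keys([])
print(WithModuleList().state_dict().keys())  # odict_keys(['layers.0.weight', 'layers.0.bias'])

Is that the cause here, and is nn.ModuleList the right fix?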