Hello everyone, I am new to this field and would appreciate some help.
I would like to modify this code so that it can generate images of size 256x256.
This is the part of the code that I want to modify; it currently targets an image size of 128x128.
class LayoutEncoder(nn.Module):
    """Encode per-object layout feature maps and fuse them per image.

    Each object's class embedding is concatenated with its latent code ``z``,
    broadcast over the object's mask, down-sampled by four convolutions,
    pooled to 8x8, fused across objects by a convolutional LSTM and finally
    refined by a stack of residual blocks.
    """

    def __init__(self, conv_dim=64, z_dim=8, embedding_dim=64, class_num=10, resi_num=6, clstm_layers=3):
        """Build the encoder layers.

        Args:
            conv_dim: Channel count of the first convolution; later
                convolutions use 2x, 4x and 8x this value.
            z_dim: Size of the per-object latent code concatenated to the
                class embedding.
            embedding_dim: Size of the class embedding vector.
            class_num: Number of object classes. ``0`` selects plain
                ``nn.BatchNorm2d`` instead of ``ConditionalBatchNorm2d``.
            resi_num: Number of residual blocks applied after the fusion.
            clstm_layers: Number of ConvLSTM layers; must be 1, 2 or 3.

        Raises:
            ValueError: If ``clstm_layers`` is not 1, 2 or 3.
        """
        super(LayoutEncoder, self).__init__()

        # Fail fast: without this check an unsupported value would leave
        # ``self.clstm`` undefined and only surface later inside forward().
        clstm_hidden_dims = {1: [64], 2: [128, 64], 3: [128, 64, 64]}
        if clstm_layers not in clstm_hidden_dims:
            raise ValueError(
                "clstm_layers must be 1, 2 or 3, got {!r}".format(clstm_layers)
            )

        self.activation = nn.ReLU(inplace=True)
        self.embedding = nn.Embedding(class_num, embedding_dim)
        self.clstm = LayoutConvLSTM(8, 512, clstm_hidden_dims[clstm_layers], (5, 5))

        # Bottleneck layers.
        layers = [ResidualBlock(dim_in=64, dim_out=64) for _ in range(resi_num)]
        self.residual = nn.Sequential(*layers)

        # NOTE(review): forward() always calls the norm layers as bn(h, objs).
        # nn.BatchNorm2d takes a single input, so the class_num == 0 variant
        # looks unusable as written -- confirm before relying on it.

        # (emb+z, 64, 64) -> (64, 64, 64)
        # NOTE(review): kernel_size=1 with padding=1 enlarges each spatial
        # dimension by 2, so the shape comments here are approximate. The
        # adaptive pool below forces the final size to 8x8 regardless.
        self.c1 = nn.Conv2d(embedding_dim + z_dim, conv_dim, kernel_size=1, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(conv_dim) if class_num == 0 else ConditionalBatchNorm2d(conv_dim, class_num)
        # (64, 64, 64) -> (128, 32, 32)
        self.c2 = nn.Conv2d(conv_dim, conv_dim * 2, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(conv_dim * 2) if class_num == 0 else ConditionalBatchNorm2d(conv_dim * 2, class_num)
        # (128, 32, 32) -> (256, 16, 16)
        self.c3 = nn.Conv2d(conv_dim * 2, conv_dim * 4, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(conv_dim * 4) if class_num == 0 else ConditionalBatchNorm2d(conv_dim * 4, class_num)
        # (256, 16, 16) -> (512, 8, 8)
        self.c4 = nn.Conv2d(conv_dim * 4, conv_dim * 8, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn4 = nn.BatchNorm2d(conv_dim * 8) if class_num == 0 else ConditionalBatchNorm2d(conv_dim * 8, class_num)
        # Force the spatial size to 8x8 whatever the convolutions produced.
        self.pool = nn.AdaptiveAvgPool2d(8)

    def forward(self, objs, masks, obj_to_img, z):
        """Encode object layouts into a fused hidden feature map.

        Args:
            objs: Object class indices, used for the embedding lookup and as
                the condition of the norm layers.
            masks: Per-object masks multiplied with the broadcast embedding;
                presumably shaped (O, 1, H, W) -- TODO confirm.
            obj_to_img: Passed through to the ConvLSTM; presumably maps each
                object to its image index -- verify against LayoutConvLSTM.
            z: Per-object latent codes, concatenated to the embeddings along
                dim 1.

        Returns:
            The output of the residual stack applied to the ConvLSTM result.
        """
        # prepare mask fm
        embeddings = self.embedding(objs)
        embeddings_z = torch.cat((embeddings, z), dim=1)
        # Reshape to (O, C, 1, 1) so the vector broadcasts over the mask.
        h = embeddings_z.view(embeddings_z.size(0), embeddings_z.size(1), 1, 1) * masks

        # downsample layout
        h = self.c1(h)
        h = self.bn1(h, objs)
        h = self.activation(h)

        h = self.c2(h)
        h = self.bn2(h, objs)
        h = self.activation(h)

        h = self.c3(h)
        h = self.bn3(h, objs)
        h = self.activation(h)

        # No activation after the last norm layer; the map is pooled directly.
        h = self.c4(h)
        h = self.bn4(h, objs)
        h = self.pool(h)

        # clstm fusion (O, 512, 8, 8) -> (n, 64, 8, 8)
        h = self.clstm(h, obj_to_img)

        # residual block
        h = self.residual(h)
        return h
class Decoder(nn.Module):
    """Decode a hidden feature map into a 3-channel image.

    Three stride-2 transposed convolutions each double the spatial size, a
    7x7 convolution produces an intermediate 3-channel image, and that image
    is up-sampled 2x (nearest neighbour) and refined by three further
    convolutions. The output is therefore 16x the input's spatial size,
    e.g. 8x8 -> 128x128.
    """

    def __init__(self, conv_dim=64):
        """Build the decoder layers.

        Args:
            conv_dim: Number of channels of the incoming hidden feature map;
                intermediate layers use 4x, 2x and 1x this value.
        """
        super(Decoder, self).__init__()
        self.activation = nn.ReLU(inplace=True)

        # (64, 8, 8) -> (256, 8, 8)
        self.c0 = nn.Conv2d(conv_dim, conv_dim * 4, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn0 = nn.BatchNorm2d(conv_dim * 4)
        # (256, 8, 8) -> (256, 16, 16)
        self.dc1 = nn.ConvTranspose2d(conv_dim * 4, conv_dim * 4, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(conv_dim * 4)
        # (256, 16, 16) -> (128, 32, 32)
        self.dc2 = nn.ConvTranspose2d(conv_dim * 4, conv_dim * 2, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(conv_dim * 2)
        # (128, 32, 32) -> (64, 64, 64)
        self.dc3 = nn.ConvTranspose2d(conv_dim * 2, conv_dim * 1, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(conv_dim * 1)
        # (64, 64, 64) -> (3, 64, 64)
        self.c4 = nn.Conv2d(conv_dim * 1, 3, kernel_size=7, stride=1, padding=3, bias=True)

        # Refinement stack applied after the 2x up-sampling in forward():
        # (3, 128, 128) -> (128, 128, 128) -> (128, 128, 128) -> (3, 128, 128)
        self.c5 = nn.Conv2d(3, conv_dim * 2, kernel_size=7, stride=1, padding=3, bias=False)
        self.bn4 = nn.BatchNorm2d(conv_dim * 2)
        self.c6 = nn.Conv2d(conv_dim * 2, conv_dim * 2, kernel_size=5, stride=1, padding=2, bias=False)
        self.bn5 = nn.BatchNorm2d(conv_dim * 2)
        self.c7 = nn.Conv2d(conv_dim * 2, 3, kernel_size=7, stride=1, padding=3, bias=True)

    def forward(self, hidden):
        """Decode ``hidden`` into an image tensor.

        Args:
            hidden: Feature map with ``conv_dim`` channels, shaped
                (N, conv_dim, H, W).

        Returns:
            Tensor shaped (N, 3, 16 * H, 16 * W). No output activation is
            applied in this module.
        """
        h = hidden
        h = self.c0(h)
        h = self.bn0(h)
        h = self.activation(h)

        h = self.dc1(h)
        h = self.bn1(h)
        h = self.activation(h)

        h = self.dc2(h)
        h = self.bn2(h)
        h = self.activation(h)

        h = self.dc3(h)
        h = self.bn3(h)
        h = self.activation(h)

        # Intermediate 3-channel image at 8x the input resolution.
        h_64 = self.c4(h)

        # upsampling 2 x
        upsample = F.interpolate(h_64, scale_factor=2, mode='nearest')

        h = self.c5(upsample)
        h = self.bn4(h)
        h = self.activation(h)

        h = self.c6(h)
        h = self.bn5(h)
        h = self.activation(h)

        h = self.c7(h)
        return h