I am trying to implement an autoencoder and for that I’m using the unet architecture to train cifar-10 data. Since there is no pre-defined architecture, I’m writing one of my own.
Here is my code.
class Block(nn.Module):
def init(self, in_ch, out_ch):
super().init()
self.conv1 = nn.Conv2d(in_ch, out_ch, 3)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(out_ch, out_ch, 3)
def forward(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.relu(x)
return x
# return self.relu(self.conv2(self.relu(self.conv1(x))))
class Encoder(nn.Module):
def init(self, chs=(3,32,64,128,256)):
super().init()
self.enc_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)])
self.pool = nn.MaxPool2d(2)
self.final_block = Block(256, 256)
def forward(self, x):
ftrs = []
for block in self.enc_blocks:
x = block(x)
ftrs.append(x)
x = self.pool(x)
x = self.final_block(x)
ftrs.append(x)
return ftrs
class Decoder(nn.Module):
def init(self, chs=(256, 256, 128, 64, 32)):
super().init()
self.chs = chs
self.upconvs = nn.ModuleList([nn.ConvTranspose2d(chs[i], chs[i+1], 2, 2) for i in range(len(chs)-1)])
self.dec_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)])
self.batch = nn.ModuleList([BatchNorm2d(num_features=chs[i]) for i in range(len(chs) - 1)])
def forward(self, x, encoder_features):
for i in range(len(self.chs)-1):
x = self.upconvs[i](x)
print(x.shape)
enc_ftrs = self.crop(encoder_features[i], x)
x = torch.cat([x, enc_ftrs], dim=1)
print(x.shape)
# x = self.batch[i](x)
# x = F.relu(x)
x = self.dec_blocks[i](x)
return x
def crop(self, enc_ftrs, x):
_, _, H, W = x.shape
enc_ftrs = torchvision.transforms.CenterCrop([H, W])(enc_ftrs)
return enc_ftrs
class UNet(nn.Module):
def init(self, enc_chs=(3,32,64,128,256), dec_chs=(256, 256, 128, 64, 32), num_class=3, retain_dim=False, out_sz=(572,572)):
super().init()
self.encoder = Encoder(enc_chs)
self.decoder = Decoder(dec_chs)
self.head = nn.ConvTranspose2d(dec_chs[-1], num_class, kernel_size = 3, padding=(1,1))
self.retain_dim = retain_dim
def forward(self, x):
enc_ftrs = self.encoder(x)
out = self.decoder(enc_ftrs[::-1][0], enc_ftrs[::-1][1:])
out = self.head(out)
if self.retain_dim:
out = F.interpolate(out, out_sz)
return out
I’m getting an error while training - “Given groups=1, weight of size [256, 256, 3, 3], expected input[64, 512, 12, 12] to have 256 channels, but got 512 channels instead”.
I’m taking some help from this blog - U-Net: A PyTorch Implementation in 60 lines of Code | Committed towards better future (amaarora.github.io)
Any corrections to the code or how to solve this error will be appreciated. If there is something I am doing wrong in the code, pls help me figure out. Thanks.