UNet architecture in PyTorch

I am trying to implement an autoencoder, and for that I’m using the U-Net architecture to train on CIFAR-10 data. Since there is no predefined architecture, I’m writing my own.
Here is my code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision


class Block(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3)

    def forward(self, x):
        # conv -> ReLU -> conv -> ReLU (no padding, so H and W each shrink by 4)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        return x
        # return self.relu(self.conv2(self.relu(self.conv1(x))))


class Encoder(nn.Module):
    def __init__(self, chs=(3, 32, 64, 128, 256)):
        super().__init__()
        self.enc_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)])
        self.pool = nn.MaxPool2d(2)
        self.final_block = Block(256, 256)

    def forward(self, x):
        ftrs = []
        for block in self.enc_blocks:
            x = block(x)
            ftrs.append(x)          # keep each block's output for the skip connections
            x = self.pool(x)
        x = self.final_block(x)     # bottleneck
        ftrs.append(x)
        return ftrs


class Decoder(nn.Module):
    def __init__(self, chs=(256, 256, 128, 64, 32)):
        super().__init__()
        self.chs = chs
        self.upconvs = nn.ModuleList([nn.ConvTranspose2d(chs[i], chs[i+1], 2, 2) for i in range(len(chs)-1)])
        self.dec_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)])
        self.batch = nn.ModuleList([nn.BatchNorm2d(num_features=chs[i]) for i in range(len(chs)-1)])

    def forward(self, x, encoder_features):
        for i in range(len(self.chs)-1):
            x = self.upconvs[i](x)
            print(x.shape)
            enc_ftrs = self.crop(encoder_features[i], x)
            x = torch.cat([x, enc_ftrs], dim=1)  # concatenate the skip along the channel dim
            print(x.shape)
            # x = self.batch[i](x)
            # x = F.relu(x)
            x = self.dec_blocks[i](x)
        return x

    def crop(self, enc_ftrs, x):
        # center-crop the skip connection to the upsampled tensor's H and W
        _, _, H, W = x.shape
        enc_ftrs = torchvision.transforms.CenterCrop([H, W])(enc_ftrs)
        return enc_ftrs


class UNet(nn.Module):
    def __init__(self, enc_chs=(3, 32, 64, 128, 256), dec_chs=(256, 256, 128, 64, 32),
                 num_class=3, retain_dim=False, out_sz=(572, 572)):
        super().__init__()
        self.encoder = Encoder(enc_chs)
        self.decoder = Decoder(dec_chs)
        self.head = nn.ConvTranspose2d(dec_chs[-1], num_class, kernel_size=3, padding=(1, 1))
        self.retain_dim = retain_dim
        self.out_sz = out_sz

    def forward(self, x):
        enc_ftrs = self.encoder(x)
        # deepest feature map is the decoder input; the rest are skips, deepest first
        out = self.decoder(enc_ftrs[::-1][0], enc_ftrs[::-1][1:])
        out = self.head(out)
        if self.retain_dim:
            out = F.interpolate(out, self.out_sz)
        return out
```
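Instead of adding `print(x.shape)` calls inside `forward`, forward pre-hooks can log the input shape of every layer, including the one that fails. The sketch below is a hypothetical helper (`trace_inputs` is not a PyTorch function) built only on `nn.Module.register_forward_pre_hook`; the usage at the end runs it on a small encoder that mimics the unpadded `Block` above, assuming a 32×32 CIFAR-10 input.

```python
import torch
import torch.nn as nn


def trace_inputs(model: nn.Module, x: torch.Tensor):
    """Record (layer name, input shape) for every leaf layer during one forward pass.

    Pre-hooks fire before a layer runs, so if the forward pass fails, the last
    record is the input of the layer that raised.
    """
    records = []
    handles = []
    for name, module in model.named_modules():
        if next(module.children(), None) is None:  # leaf layers only
            def pre_hook(mod, args, name=name):
                records.append((name, tuple(args[0].shape)))
            handles.append(module.register_forward_pre_hook(pre_hook))
    try:
        with torch.no_grad():
            model(x)
    except RuntimeError as err:
        if records:
            print(f"failed in {records[-1][0]} with input {records[-1][1]}: {err}")
        else:
            print(f"forward failed: {err}")
    finally:
        for handle in handles:
            handle.remove()  # leave the model without hooks afterwards
    return records


def unpadded_block(in_ch, out_ch):
    # Same layers as the question's Block: two 3x3 convs without padding.
    return nn.Sequential(nn.Conv2d(in_ch, out_ch, 3), nn.ReLU(),
                         nn.Conv2d(out_ch, out_ch, 3), nn.ReLU())


encoder = nn.Sequential(unpadded_block(3, 32), nn.MaxPool2d(2),
                        unpadded_block(32, 64), nn.MaxPool2d(2),
                        unpadded_block(64, 128))
records = trace_inputs(encoder, torch.randn(1, 3, 32, 32))
for name, shape in records:
    print(name, shape)
```

Each unpadded block trims 4 pixels from H and W, so on a 32×32 input the third block's last conv already sees a 1×1 map, which leaves no room for a fourth block or a bottleneck. The helper prints a failure instead of re-raising so that the partial trace is still returned; re-raising is an equally valid choice.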

I’m getting an error while training - “Given groups=1, weight of size [256, 256, 3, 3], expected input[64, 512, 12, 12] to have 256 channels, but got 512 channels instead”.
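The error comes from the channel bookkeeping at the first decoder stage. `upconvs[0]` is `ConvTranspose2d(256, 256, 2, 2)`, so it keeps 256 channels; concatenating the 256-channel skip from the last encoder block gives 512 channels, but `dec_blocks[0]` is `Block(256, 256)`, whose first conv expects 256. The minimal reproduction below uses stand-alone layers with the same shapes; the 6×6 bottleneck size is an assumption chosen so the result matches the 12×12 in the message.

```python
import torch
import torch.nn as nn

# Stand-alone layers with the same shapes as the first decoder stage in the question.
upconv = nn.ConvTranspose2d(256, 256, kernel_size=2, stride=2)  # like upconvs[0]: 256 -> 256
block_conv1 = nn.Conv2d(256, 256, kernel_size=3)                # like dec_blocks[0].conv1

bottleneck = torch.randn(1, 256, 6, 6)  # final_block output (6x6 is an assumed size)
skip = torch.randn(1, 256, 12, 12)      # last enc_blocks output, 256 channels

x = upconv(bottleneck)                  # [1, 256, 12, 12]: channel count unchanged
x = torch.cat([x, skip], dim=1)         # [1, 512, 12, 12]: channels add up on concat
print(x.shape)

error_message = None
try:
    block_conv1(x)                      # weight expects 256 input channels
except RuntimeError as err:
    error_message = str(err)
print(error_message)
```

In the usual U-Net layout, each `ConvTranspose2d` halves the channels (`chs[i]` to `chs[i+1]`), so after concatenating a skip with `chs[i+1]` channels the tensor has `2 * chs[i+1] == chs[i]` channels, which is exactly what `Block(chs[i], chs[i+1])` expects. Starting `dec_chs` with 256, 256 breaks that rule; one possible fix is to let `final_block` output 512 channels and start `dec_chs` at 512.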

I’m following this blog post: U-Net: A PyTorch Implementation in 60 lines of Code | Committed towards better future (amaarora.github.io)

Any corrections to the code, or suggestions on how to solve this error, would be appreciated. If I am doing something wrong in the code, please help me figure it out. Thanks.

Not sure why you are doing this when plenty of people have already implemented UNet in PyTorch. Just google it. There are a few good YouTube videos that walk you through coding it from scratch.

Thanks for the reply. I searched the internet, and all the U-Net implementations I found are for 572×572 images. However, I need to implement it for CIFAR-10 images, which are 32×32. I changed the architecture a bit, but that change is causing the error above.
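For 32×32 inputs, one common adaptation is to give every 3×3 convolution `padding=1` so each block keeps H and W. Four 2×2 poolings then take 32 to 16, 8, 4 and 2, every skip connection already matches its upsampled tensor, and `crop` becomes unnecessary. The sketch below is one way to do this, not the only one. Its assumptions: padded convs, a bottleneck that doubles 256 to 512 channels so every upconv halves the channels, a 1×1 `Conv2d` head in place of the question's `ConvTranspose2d`, and `retain_dim`/`out_sz` dropped because the output already has the input's size.

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    """Two 3x3 convs with padding=1, so H and W are preserved."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv2(self.relu(self.conv1(x))))


class Encoder(nn.Module):
    def __init__(self, chs=(3, 32, 64, 128, 256), bottleneck_ch=512):
        super().__init__()
        self.enc_blocks = nn.ModuleList(
            [Block(chs[i], chs[i + 1]) for i in range(len(chs) - 1)])
        self.pool = nn.MaxPool2d(2)
        # Bottleneck doubles the channel count (256 -> 512 with the defaults).
        self.final_block = Block(chs[-1], bottleneck_ch)

    def forward(self, x):
        ftrs = []
        for block in self.enc_blocks:
            x = block(x)
            ftrs.append(x)          # skip connection, saved before pooling
            x = self.pool(x)
        return self.final_block(x), ftrs


class Decoder(nn.Module):
    def __init__(self, chs=(512, 256, 128, 64, 32)):
        super().__init__()
        # Each upconv halves the channels: chs[i] -> chs[i+1].
        self.upconvs = nn.ModuleList(
            [nn.ConvTranspose2d(chs[i], chs[i + 1], 2, 2) for i in range(len(chs) - 1)])
        # After concatenating the skip (chs[i+1] channels) the input has
        # 2 * chs[i+1] == chs[i] channels, which is what each Block expects.
        self.dec_blocks = nn.ModuleList(
            [Block(chs[i], chs[i + 1]) for i in range(len(chs) - 1)])

    def forward(self, x, skips):
        for upconv, block, skip in zip(self.upconvs, self.dec_blocks, skips):
            x = upconv(x)
            x = torch.cat([x, skip], dim=1)
            x = block(x)
        return x


class UNet(nn.Module):
    def __init__(self, enc_chs=(3, 32, 64, 128, 256),
                 dec_chs=(512, 256, 128, 64, 32), num_class=3):
        super().__init__()
        self.encoder = Encoder(enc_chs, bottleneck_ch=dec_chs[0])
        self.decoder = Decoder(dec_chs)
        self.head = nn.Conv2d(dec_chs[-1], num_class, kernel_size=1)

    def forward(self, x):
        bottleneck, skips = self.encoder(x)
        out = self.decoder(bottleneck, skips[::-1])  # deepest skip first
        return self.head(out)


model = UNet()
out = model(torch.randn(8, 3, 32, 32))
print(out.shape)  # torch.Size([8, 3, 32, 32])
```

Returning the bottleneck separately from the skip list avoids the `enc_ftrs[::-1][0]` / `enc_ftrs[::-1][1:]` slicing. With padded convs, H and W only need to be divisible by 2**4 = 16, which 32 is.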

I think having the right size just makes the coding easier. I resize my images to something divisible by 4 before feeding them into my network. I am not familiar with the dataset, but your tensor seems weird to me. I am a novice as well, though; I mostly do semantic segmentation, and there the tensors always have the form [B, C, H, W]. In general, if your shapes are right at the start, it is painless from there.
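The divisibility rule generalizes: with padded convolutions, every `MaxPool2d(2)` halves H and W, so a network with n pooling layers needs both to be divisible by 2**n for the upsampled tensors to line up with the skip connections (16 for four pools). CIFAR-10 batches produced by torchvision's `ToTensor` and a `DataLoader` already have the shape [B, C, H, W] = [batch, 3, 32, 32], which satisfies that. For other sizes, the sketch below zero-pads rather than resizes (a swapped-in technique, not what the reply describes); `pad_to_multiple` is a hypothetical helper built on `torch.nn.functional.pad`.

```python
import torch
import torch.nn.functional as F


def pad_to_multiple(x: torch.Tensor, multiple: int) -> torch.Tensor:
    """Zero-pad the bottom/right of an NCHW tensor so H and W are divisible by `multiple`."""
    _, _, h, w = x.shape
    pad_h = (-h) % multiple
    pad_w = (-w) % multiple
    # F.pad takes (left, right, top, bottom) for the last two dims (W, then H).
    return F.pad(x, (0, pad_w, 0, pad_h))


num_pools = 4                  # MaxPool2d(2) layers in the encoder
multiple = 2 ** num_pools      # 16: each pool halves H and W

batch = torch.randn(64, 3, 30, 30)        # hypothetical odd-sized input, [B, C, H, W]
padded = pad_to_multiple(batch, multiple)
print(padded.shape)                        # torch.Size([64, 3, 32, 32])
```

Padding keeps the original pixels unscaled, so the network's output can be cropped back to the input size; resizing, as the reply does, is simpler but distorts the image slightly.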

I watched this guy when I was just starting out; IMHO he does a decent job of explaining each part of the UNet.

Good Luck!