Hi. I made Unet model with this code.
import torch
import torch.nn as nn
from collections import OrderedDict
class UNet(nn.Module):
def __init__(self, in_channels, out_channels):
super(UNet, self).__init__()
def ConvBlock(in_channels, out_channels, kernel_size=3, stride=1, padding=1, name=None):
return nn.Sequential(
OrderedDict(
[
(name+"conv1", nn.Conv2d(in_channels, out_channels, kernel_size,
stride, padding, bias=True)),
(name+"bnorm1", nn.BatchNorm2d(num_features=out_channels)),
(name+"relu1", nn.ReLU(inplace=True)),
(name+"conv2",nn.Conv2d(out_channels, out_channels, kernel_size,
stride, padding, bias=True)),
(name+"bnorm1", nn.BatchNorm2d(num_features=out_channels)),
(name+"relu1", nn.ReLU(inplace=True))
]
)
)
self.enc1 = ConvBlock(in_channels, 64, name="enc1")
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.enc2 = ConvBlock(64, 128, name="enc2")
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.enc3 = ConvBlock(128, 256, name="enc3")
self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
self.enc4 = ConvBlock(256, 512, name="enc4")
self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
self.bottleneck = ConvBlock(512, 1024, name="bottleneck")
self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
self.dec4 = ConvBlock(1024, 512, name="dec4")
self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
self.dec3 = ConvBlock(512, 256, name="dec3")
self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
self.dec2 = ConvBlock(256, 128, name="dec2")
self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
self.dec1 = ConvBlock(128, 64, name="dec1")
self.out = nn.Conv2d(64, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x):
enc1 = self.enc1(x)
enc2 = self.enc2(self.pool1(enc1))
enc3 = self.enc3(self.pool2(enc2))
enc4 = self.enc4(self.pool3(enc3))
bottleneck = self.bottleneck(self.pool4(enc4))
dec4 = self.upconv4(bottleneck)
print(enc4.shape, dec4.shape)
dec4 = torch.cat((dec4, enc4), dim=1)
dec4 = self.dec4(dec4)
dec3 = self.upconv3(dec4)
dec3 = torch.cat((dec3, enc3), dim=1)
dec3 = self.dec3(dec3)
dec2 = self.upconv2(dec3)
dec2 = torch.cat((dec2, enc2), dim=1)
dec2 = self.dec2(dec2)
dec1 = self.upconv1(dec2)
dec1 = torch.cat((dec1, enc1), dim=1)
dec1 = self.dec1(dec1)
return torch.sigmoid(self.conv(dec1))
and I tested it with this code.
import torch
from UNet import UNet
im = torch.randn(1, 1, 572, 572)
unet = UNet(in_channels=1, out_channels=10)
x = unet(im)
print(x)
but I got error below.
/Users/eden/opt/anaconda3/envs/vein/bin/python /Users/eden/PycharmProjects/palmprint-segmentation-pytorch/train.py
torch.Size([1, 512, 71, 71]) torch.Size([1, 512, 70, 70])
Traceback (most recent call last):
File "/Users/eden/PycharmProjects/palmprint-segmentation-pytorch/train.py", line 6, in <module>
x = unet(im)
File "/Users/eden/opt/anaconda3/envs/vein/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
result = self.forward(*input, **kwargs)
File "/Users/eden/PycharmProjects/palmprint-segmentation-pytorch/UNet.py", line 59, in forward
dec4 = torch.cat((dec4, enc4), dim=1)
RuntimeError: Sizes of tensors must match except in dimension 1. Got 70 and 71 in dimension 2 (The offending index is 1)
how do I fix it?