Hi, I have the following snippet in Unet structure:
class DoubleConv(nn.Module): def __init__(self,in_channels, out_channels, mid_channels=None): super(DoubleConv,self).__init__() if not mid_channels: mid_channels = out_channels self.d_conv = nn.Sequential( nn.Conv2d(in_channels,mid_channels,kernel_size=3, padding=1), nn.BatchNorm2d(mid_channels), nn.ReLU(inplace=True), nn.Conv2d(mid_channels,out_channels,kernel_size=3,padding=1), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)) def forward(self,x): ## The spatial dim is not changed: (out_dim, h,w) return(self.d_conv(x)) class Up(nn.Module): """Upscaling then double conv""" def __init__(self,in_channels, out_channels, bilinear=True): super(Up,self).__init__() # if bilinear, use the normal convelutions to reduce the number of channels if bilinear: self.up = nn.Upsample(scale_factor=2,mode='bilinear', align_corners=True) # here we devide the number of filters self.conv = DoubleConv(in_channels,out_channels,in_channels//2) else: self.up = nn.ConvTranspose2d(in_channels,in_channels//2,kernel_size=2,stride=2) self.conv = DoubleConv(in_channels,out_channels) def forward(self,x1,x2): x1 = self.up(x1) print(x1.size()) # input is CHW diffy = x2.size() - x1.size() diffx = x2.size()- x1.size() x1 = to_pil_image(x1) x1 = F.pad(x1, [diffx // 2, diffx - diffx // 2,diffy // 2, diffy - diffy // 2]) x1= to_tensor(x1) x = torch.cat([x2, x1], dim=1) return(self.conv(x))
But when I run it to two random tensors, I get two errors:
First: img should be PIL Image. Got <class ‘torch.Tensor’>
while in documentation of pytorch.org it was written that
torchvision.transforms.functional.pad can be applied on both PIL and Tensor images.
When I change the tensors in PIL then I get an error in dimension. While in pytorch.org I read that it dose not matter what dimension you have , just it is important we have […, h,w]. In the following when I put tensors if dimension [batch_num, c,h,w] it gives me the error of size 4 that should be 2 or 3 and when I omit the batch_num from the tensors it returns an error of dimension again.
x2 = torch.randn(1,3,130,130) x1 = torch.randn(1,3,126,126) upp = Up(3,32,True) result = upp(x1,x2)
ValueError: pic should be 2/3 dimensional. Got 4 dimensions.