type or p#------------------------------------------------------------------------------#
import torch
import torch.nn as nn
import torch.nn.functional as F
#------------------------------------------------------------------------------#
class DoubleConv(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
)
def forward(self, x):
return self.double_conv(x)
#-----------------------------------------------------------------------------#
class conv1_cross_1(nn.Module):
def __init__(self,in_channels,out_channels):
super().__init__()
self.conv1x1 = nn.Sequential (
nn.Conv2d(in_channels, out_channels, kernel_size=1,padding=0)
)
def forward(self, x):
return self.conv1x1(x)
#--------------------------------------------------------------------------------#
class Down(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2)
)
def forward(self, x):
return self.maxpool_conv(x)
#------------------------------------------------------------------------------#
class Up(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
# print (in_channels, out_channels)
self.up_sampling = nn.Sequential (
nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True))
def forward(self, x):
x = self.up_sampling(x)
return x
#-------------------------------------------------------------------------------#
class OutConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(OutConv, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
def forward(self, x):
return self.conv(x)
#------------------------------------------------------------------------------#
class UNet(nn.Module):
def __init__(self, n_channels, n_classes):
super(UNet, self).__init__()
# self.n_channels = n_channels
# self.n_classes = n_classes
#------------------------------#
# ENCODER #
#------------------------------#
self.inc1 = DoubleConv(1,96)
self.down1 = Down(96,96)
self.inc2 = DoubleConv(96,96)
self.down2 = Down(96,96)
self.inc3 = DoubleConv(96,96)
self.down3 = Down(96,96)
self.inc4 = DoubleConv(96,96)
self.down4 = Down(96,96)
#------------------------------#
# BOTTLENECK #
#------------------------------#
self.bot1=conv1_cross_1 (96,96)
#-------------------------------#
# UP #
#-------------------------------#
self.up1 = Up(96,96)
self.up2 = Up(96,96)
self.up3 = Up(96,96)
self.up4 = Up(96,96)
self.outc = OutConv(96,3)
def forward(self, x):
x1 = self.inc1(x)
x2 = self.down1(x1) # maxpooling
#--------------------------------------------------------#
x3 = self.inc2(x2)
x4 = self.down2 (x3) # maxpooling
#--------------------------------------------------------#
x5 = self.inc3(x4)
x6 = self.down3(x5) # maxpooling
#--------------------------------------------------------#
x7 = self.inc4(x6)
x8 = self.down4 (x7) # maxpooling
#--------------------------------------------------------#
x9 = self.bot1(x8)
#--------------------------------------------------------#
x = self.up1(x9)
x = torch.cat([x, x7], dim=1)
print (x.shape)
#--------------------------------------------------------#
x = self.up2 (x)
x = torch.cat([x, x5], dim=1)
print (x.shape)
#--------------------------------------------------------#
x = self.up3 (x)
x = torch.cat([x, x3], dim=1)
#--------------------------------------------------------#
x = self.up4 (x)
x = torch.cat([x, x1], dim=1)
#--------------------------------------------------------#
logits = self.outc(x)
return logits
# def use_checkpointing(self):
# self.inc = torch.utils.checkpoint(self.inc)
# self.down1 = torch.utils.checkpoint(self.down1)
# self.down2 = torch.utils.checkpoint(self.down2)
# self.down3 = torch.utils.checkpoint(self.down3)
# self.down4 = torch.utils.checkpoint(self.down4)
# self.up1 = torch.utils.checkpoint(self.up1)
# self.up2 = torch.utils.checkpoint(self.up2)
# self.up3 = torch.utils.checkpoint(self.up3)
# self.up4 = torch.utils.checkpoint(self.up4)
# self.outc = torch.utils.checkpoint(self.outc)
#------------------------------------------------------------------------------#
if __name__ == "__main__":
# inputs = torch.randn([1,1,256,256])
# image = cv2.imread (r'C:\Users\Idrees Bhat\Desktop\Research\Insha\Dataset\o_gray\1.png',0)
# convert_tensor = transforms.ToTensor()
# inputs=convert_tensor(image)
# print (inputs.shape)
# print (type(inputs)) # [1,256,256]
# inputs = torch.randn((2,1,256,256))
# y = conv_block(1,96)
# y(inputs)
# e = encoder_block(1, 96)
# x,p = e(inputs)
# print (x.shape, p.shape)
inputs = torch.randn((1,1,256,256))
# print (inputs.shape)
# y = conv1_cross_1(1, 96)
# y = Down(1, 96)
# y = DoubleConv (1,96)
# y(inputs)
# y = Up(1,96)
# y(inputs)
# y = OutConv(1,3)
# y(inputs)
# d = conv1_cross_1 (96,1)
# d(inputs)
d = UNet(1,96)
d(inputs)
# skip = torch. randn((2,32,512,512))
# d = decoder_block(64, 96)
# e = conv1_cross_1(96,96)
# print (e)
# d = build_unet()
# d(inputs)
# e(inputs)
# print (inputs.shape,skip.shape,x.shape)
#---------------------------------------------------#
Error:
runfile(‘C:/Users/Idrees Bhat/Desktop/insha/this_time.py’, wdir=‘C:/Users/Idrees Bhat/Desktop/insha’)
torch.Size([1, 192, 32, 32])
Traceback (most recent call last):
File ~\AppData\Local\anaconda3\Lib\site-packages\spyder_kernels\py3compat.py:356 in compat_exec
exec(code, globals, locals)
File c:\users\idrees bhat\desktop\insha\this_time.py:180
d(inputs)
File ~\AppData\Local\anaconda3\Lib\site-packages\torch\nn\modules\module.py:1518 in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File ~\AppData\Local\anaconda3\Lib\site-packages\torch\nn\modules\module.py:1527 in _call_impl
return forward_call(*args, **kwargs)
File c:\users\idrees bhat\desktop\insha\this_time.py:126 in forward
x = self.up2 (x)
File ~\AppData\Local\anaconda3\Lib\site-packages\torch\nn\modules\module.py:1518 in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File ~\AppData\Local\anaconda3\Lib\site-packages\torch\nn\modules\module.py:1527 in _call_impl
return forward_call(*args, **kwargs)
File c:\users\idrees bhat\desktop\insha\this_time.py:52 in forward
x = self.up_sampling(x)
File ~\AppData\Local\anaconda3\Lib\site-packages\torch\nn\modules\module.py:1518 in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File ~\AppData\Local\anaconda3\Lib\site-packages\torch\nn\modules\module.py:1527 in _call_impl
return forward_call(*args, **kwargs)
File ~\AppData\Local\anaconda3\Lib\site-packages\torch\nn\modules\container.py:215 in forward
input = module(input)
File ~\AppData\Local\anaconda3\Lib\site-packages\torch\nn\modules\module.py:1518 in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File ~\AppData\Local\anaconda3\Lib\site-packages\torch\nn\modules\module.py:1527 in _call_impl
return forward_call(*args, **kwargs)
File ~\AppData\Local\anaconda3\Lib\site-packages\torch\nn\modules\conv.py:952 in forward
return F.conv_transpose2d(
RuntimeError: Given transposed=1, weight of size [96, 96, 2, 2], expected input[1, 192, 32, 32] to have 96 channels, but got 192 channels instead
#------------------------------------------------------------------------------------------------------------------------#