I've built an Attention U-Net with SqueezeExcitation blocks embedded in the encoder. Here is my model:
```python
import torch
import torch.nn as nn

# U-Net and its parts
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_op = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv_op(x)
class SqueezeExcitation(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.flat = nn.Flatten()
        self.ch = nn.Sequential(
            nn.Linear(in_features=channels, out_features=channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        out = self.avg_pool(x)
        out = self.flat(out)
        out = self.ch(out)
        out = out.view(out.size(0), out.size(1), 1, 1)
        return x * out
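
# Shape walkthrough for SqueezeExcitation with, e.g., channels=64 (first encoder stage):
#   (N, 64, H, W) -> AdaptiveAvgPool2d(1) -> (N, 64, 1, 1) -> Flatten -> (N, 64)
#   -> Linear(64, 64 // 16 = 4) -> ReLU -> Linear(4, 64) -> Sigmoid -> view -> (N, 64, 1, 1)
# The excitation weights then rescale x channel-wise. Note that channels // reduction
# must be at least 1, i.e. channels >= reduction (here 16).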
class DownSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = DoubleConv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.drop = nn.Dropout2d(0.2)
        self.se = SqueezeExcitation(in_channels)

    def forward(self, x):
        down = self.conv(x)
        down = self.se(down)
        down = self.drop(down)
        p = self.pool(down)
        return down, p
class UpSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        x = torch.cat([x1, x2], 1)
        return self.conv(x)
class Attention(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.normal = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.down = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=2)
        self.one = nn.Conv2d(out_channels, 1, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()
        self.resample = nn.Upsample(scale_factor=2)

    def forward(self, X, skip_X):
        x = self.normal(X)
        skip = self.down(skip_X)
        x = x + skip
        x = self.relu(x)
        x = self.one(x)
        x = self.sigmoid(x)
        x = self.resample(x)
        return x * skip_X
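
# Shape walkthrough for the attention gate, e.g. attention_gate_1(b, down_4)
# with a 256x256 input:
#   X = b:           (N, 1024, 16, 16) -> normal (1x1 conv)    -> (N, 512, 16, 16)
#   skip_X = down_4: (N, 512, 32, 32)  -> down (1x1, stride 2) -> (N, 512, 16, 16)
#   sum -> ReLU -> one (1x1 conv) -> (N, 1, 16, 16) -> Sigmoid
#   -> Upsample(x2) -> (N, 1, 32, 32), used to rescale skip_X  -> (N, 512, 32, 32)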
class UNet(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.down_convolution_1 = DownSample(in_channels, 64)
        self.down_convolution_2 = DownSample(64, 128)
        self.down_convolution_3 = DownSample(128, 256)
        self.down_convolution_4 = DownSample(256, 512)
        self.bottle_neck = DoubleConv(512, 1024)
        self.attention_gate_1 = Attention(1024, 512)
        self.up_convolution_1 = UpSample(1024, 512)
        self.attention_gate_2 = Attention(512, 256)
        self.up_convolution_2 = UpSample(512, 256)
        self.attention_gate_3 = Attention(256, 128)
        self.up_convolution_3 = UpSample(256, 128)
        self.attention_gate_4 = Attention(128, 64)
        self.up_convolution_4 = UpSample(128, 64)
        self.out = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1)

    def forward(self, x):
        down_1, p1 = self.down_convolution_1(x)
        down_2, p2 = self.down_convolution_2(p1)
        down_3, p3 = self.down_convolution_3(p2)
        down_4, p4 = self.down_convolution_4(p3)
        b = self.bottle_neck(p4)
        a1 = self.attention_gate_1(b, down_4)
        up_1 = self.up_convolution_1(b, a1)
        a2 = self.attention_gate_2(up_1, down_3)
        up_2 = self.up_convolution_2(up_1, a2)
        a3 = self.attention_gate_3(up_2, down_2)
        up_3 = self.up_convolution_3(up_2, a3)
        a4 = self.attention_gate_4(up_3, down_1)
        up_4 = self.up_convolution_4(up_3, a4)
        out = self.out(up_4)
        return torch.sigmoid(out)
```
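For what it's worth, the SqueezeExcitation block on its own runs fine when it is built with at least `reduction` channels; this minimal check (64 channels, matching the first encoder stage) completes without error:

```python
se = SqueezeExcitation(64)
t = torch.randn(8, 64, 256, 256)
print(se(t).shape)  # torch.Size([8, 64, 256, 256]) -- input shape preserved
```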
My hyperparameters are BATCH_SIZE = 8 and INPUT_SHAPE_UNET = (8, 3, 256, 256).
Here is the snippet I use to check that the model works:
```python
dummy_input = torch.randn((8, 3, 256, 256))
model = UNet(3, 1)
model.forward(dummy_input)
```
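If the call succeeded, I would expect an output of shape (8, 1, 256, 256), since the four pool/up-conv pairs cancel out spatially and the final 1x1 convolution maps 64 channels to num_classes = 1:

```python
out = model(dummy_input)
print(out.shape)  # expected: torch.Size([8, 1, 256, 256])
```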
Instead, I get the following error:
```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[143], line 4
      2 dummy_input = torch.randn((8,3,256,256))
      3 model = UNet(3,1)
----> 4 model.forward(dummy_input)

Cell In[142], line 106, in UNet.forward(self, x)
    105 def forward(self, x):
--> 106     down_1, p1 = self.down_convolution_1(x)
    107     down_2, p2 = self.down_convolution_2(p1)
    108     down_3, p3 = self.down_convolution_3(p2)

[... repeated torch\nn\modules\module.py dispatch frames
 (_wrapped_call_impl / _call_impl) omitted here and between the frames below ...]

Cell In[142], line 44, in DownSample.forward(self, x)
     42 def forward(self, x):
     43     down = self.conv(x)
---> 44     down = self.se(down)
     45     down = self.drop(down)
     46     p = self.pool(down)

Cell In[142], line 30, in SqueezeExcitation.forward(self, x)
     28     out = self.avg_pool(x)
     29     out = self.flat(out)
---> 30     out = self.ch(out)
     31     out = out.view(out.size(0), out.size(1), 1, 1)
     32     return x * out

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\nn\modules\container.py:217, in Sequential.forward(self, input)
    215 def forward(self, input):
    216     for module in self:
--> 217         input = module(input)
    218     return input

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\nn\modules\linear.py:116, in Linear.forward(self, input)
    115 def forward(self, input: Tensor) -> Tensor:
--> 116     return F.linear(input, self.weight, self.bias)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (8x64 and 3x0)
```
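The same error shows up with just the first encoder block, so the problem seems to live in DownSample / SqueezeExcitation rather than in the decoder:

```python
block = DownSample(3, 64)        # first encoder stage: in_channels=3, out_channels=64
x = torch.randn(8, 3, 256, 256)
block(x)  # RuntimeError: mat1 and mat2 shapes cannot be multiplied (8x64 and 3x0)
```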
Please help me solve this issue as soon as possible. @ptrblck and others, any pointers would be much appreciated.