I want to train my unet model with input shape 128x128x128, but i got an error like this
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<timed eval> in <module>
/tmp/ipykernel_17/3349595807.py in run(self)
91 def run(self):
92 for epoch in range(self.num_epochs):
---> 93 self._do_epoch(epoch, "train")
94 with torch.no_grad():
95 val_loss = self._do_epoch(epoch, "val")
/tmp/ipykernel_17/3349595807.py in _do_epoch(self, epoch, phase)
66 for itr, data_batch in enumerate(dataloader):
67 images, targets = data_batch['image'], data_batch['mask']
---> 68 loss, logits = self._compute_loss_and_outputs(images, targets)
69 loss = loss / self.accumulation_steps
70 if phase == "train":
/tmp/ipykernel_17/3349595807.py in _compute_loss_and_outputs(self, images, targets)
51 images = images.to(self.device)
52 targets = targets.to(self.device)
---> 53 logits = self.net(images)
54 loss = self.criterion(logits, targets)
55 return loss, logits
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []
/tmp/ipykernel_17/948357362.py in forward(self, x)
138 # Level 2 context pathway
139 out = self.conv3d_c2(out)
--> 140 out = self.SE_c2(out)
141 residual_2 = out
142 out = self.norm_lrelu_conv_c2(out)
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []
/tmp/ipykernel_17/673886512.py in forward(self, input_tensor)
116 :return: output_tensor
117 """
--> 118 output_tensor = torch.max(self.cSE(input_tensor), self.sSE(input_tensor))
119 return output_tensor
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []
/tmp/ipykernel_17/673886512.py in forward(self, input_tensor)
38
39 # channel excitation
---> 40 fc_out_1 = self.relu(self.fc1(squeeze_tensor.view(batch_size, num_channels)))
41 print(f"fc_out_1 shape: {fc_out_1.shape}")
42 fc_out_2 = self.sigmoid(self.fc2(fc_out_1))
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/linear.py in forward(self, input)
101
102 def forward(self, input: Tensor) -> Tensor:
--> 103 return F.linear(input, self.weight, self.bias)
104
105 def extra_repr(self) -> str:
RuntimeError: mat1 and mat2 shapes cannot be multiplied (3x16 and 8x4)
With my model looks like this
import torch
from torch import nn as nn
from torch.nn import functional as F
class ChannelSELayer3D(nn.Module):
"""
3D extension of Squeeze-and-Excitation (SE) block described in:
*Hu et al., Squeeze-and-Excitation Networks, arXiv:1709.01507*
*Zhu et al., AnatomyNet, arXiv:arXiv:1808.05238*
"""
def __init__(self, num_channels, reduction_ratio=2, norm='None'):
"""
:param num_channels: No of input channels
:param reduction_ratio: By how much should the num_channels should be reduced
"""
super(ChannelSELayer3D, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool3d(1)
num_channels_reduced = num_channels // reduction_ratio
self.reduction_ratio = reduction_ratio
self.fc1 = nn.Linear(num_channels, num_channels_reduced, bias=True)
self.fc2 = nn.Linear(num_channels_reduced, num_channels, bias=True)
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()
self.norm = norm
self.bn = nn.BatchNorm3d(num_channels)
def forward(self, input_tensor):
"""
:param input_tensor: X, shape = (batch_size, num_channels, D, H, W)
:return: output tensor
"""
batch_size, num_channels, D, H, W = input_tensor.size()
# Average along each channel
squeeze_tensor = self.avg_pool(input_tensor)
# channel excitation
fc_out_1 = self.relu(self.fc1(squeeze_tensor.view(batch_size, num_channels)))
print(f"fc_out_1 shape: {fc_out_1.shape}")
fc_out_2 = self.sigmoid(self.fc2(fc_out_1))
print(f"fc_out_2 shape: {fc_out_2.shape}")
output_tensor = torch.mul(input_tensor, fc_out_2.view(batch_size, num_channels, 1, 1, 1))
if self.norm == 'BN':
output_tensor = self.bn(output_tensor)
return output_tensor
class SpatialSELayer3D(nn.Module):
"""
3D extension of SE block -- squeezing spatially and exciting channel-wise described in:
*Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, MICCAI 2018*
"""
def __init__(self, num_channels, norm = 'None'):
"""
:param num_channels: No of input channels
"""
super(SpatialSELayer3D, self).__init__()
self.conv = nn.Conv3d(num_channels, 1, 1)
self.sigmoid = nn.Sigmoid()
self.norm = norm
self.bn = nn.BatchNorm3d(1)
def forward(self, input_tensor, weights=None):
"""
:param weights: weights for few shot learning
:param input_tensor: X, shape = (batch_size, num_channels, D, H, W)
:return: output_tensor
"""
# channel squeeze
batch_size, channel, D, H, W = input_tensor.size()
if weights:
weights = weights.view(1, channel, 1, 1)
out = F.conv2d(input_tensor, weights)
else:
out = self.conv(input_tensor)
if self.norm == 'BN':
out = self.bn(out)
squeeze_tensor = self.sigmoid(out)
# spatial excitation
output_tensor = torch.mul(input_tensor, squeeze_tensor.view(batch_size, 1, D, H, W))
return output_tensor
class ChannelSpatialSELayer3D(nn.Module):
"""
3D extension of concurrent spatial and channel squeeze & excitation:
*Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, arXiv:1803.02579*
"""
def __init__(self, num_channels, reduction_ratio=2, norm='None'):
"""
:param num_channels: No of input channels
:param reduction_ratio: By how much should the num_channels should be reduced
"""
super(ChannelSpatialSELayer3D, self).__init__()
self.cSE = ChannelSELayer3D(num_channels, reduction_ratio, norm=norm)
self.sSE = SpatialSELayer3D(num_channels, norm=norm)
self.norm = norm
def forward(self, input_tensor):
"""
:param input_tensor: X, shape = (batch_size, num_channels, D, H, W)
:return: output_tensor
"""
output_tensor = torch.max(self.cSE(input_tensor), self.sSE(input_tensor))
return output_tensor
class Modified3DUNet(nn.Module):
def __init__(self, in_channels=4, n_classes=3, base_n_filter = 8):
super(Modified3DUNet, self).__init__()
self.in_channels = in_channels
self.n_classes = n_classes
self.base_n_filter = base_n_filter
output_channels = base_n_filter
filters = [4, 8, 16, 32, 64]
self.lrelu = nn.LeakyReLU()
self.dropout3d = nn.Dropout3d(p=0.6)
self.upsacle = nn.Upsample(scale_factor=4, mode='nearest')
self.softmax = nn.Softmax(dim=1)
# Level 1 context pathway
# self.conv3d_c1_1 = nn.Conv3d(self.in_channels, self.base_n_filter, kernel_size=3, stride=1, padding=1, bias=False)
self.conv3d_c1_1 = DepthwiseSeparableConv3d(self.in_channels, output_channels, kernel_size=3, padding=1, stride=1, dilation=1, groups=1, bias=False, kernels_per_layer=1)
# self.conv3d_c1_2 = nn.Conv3d(self.base_n_filter, self.base_n_filter, kernel_size=3, stride=1, padding=1, bias=False)
self.conv3d_c1_2 = DepthwiseSeparableConv3d(self.base_n_filter, self.base_n_filter, kernel_size=3, padding=1, stride=1, dilation=1, groups=1, bias=False, kernels_per_layer=1)
self.SE_c1 = ChannelSpatialSELayer3D(self.base_n_filter*1, norm='None')
self.lrelu_conv_c1 = self.lrelu_conv(self.base_n_filter, self.base_n_filter)
self.inorm3d_c1 = nn.InstanceNorm3d(self.base_n_filter)
# Level 2 context pathway
# self.conv3d_c2 = nn.Conv3d(self.base_n_filter, self.base_n_filter*2, kernel_size=3, stride=2, padding=1, bias=False)
self.conv3d_c2 = DepthwiseSeparableConv3d(self.base_n_filter, self.base_n_filter*2, kernel_size=3, padding=1, stride=2, dilation=1, groups=1, bias=False, kernels_per_layer=1)
self.SE_c2 = ChannelSpatialSELayer3D(self.base_n_filter, norm='None')
self.norm_lrelu_conv_c2 = self.norm_lrelu_conv(self.base_n_filter*2, self.base_n_filter*2)
self.inorm3d_c2 = nn.InstanceNorm3d(self.base_n_filter*2)
# Level 3 context pathway
# self.conv3d_c3 = nn.Conv3d(self.base_n_filter*2, self.base_n_filter*4, kernel_size=3, stride=2, padding=1, bias=False)
self.conv3d_c3 = DepthwiseSeparableConv3d(self.base_n_filter*2, self.base_n_filter*4, kernel_size=3, padding=1, stride=1, dilation=1, groups=1, bias=False, kernels_per_layer=1)
self.SE_c3 = ChannelSpatialSELayer3D(self.base_n_filter*2, norm='None')
self.norm_lrelu_conv_c3 = self.norm_lrelu_conv(self.base_n_filter*4, self.base_n_filter*4)
self.inorm3d_c3 = nn.InstanceNorm3d(self.base_n_filter*4)
# Level 4 context pathway
# self.conv3d_c4 = nn.Conv3d(self.base_n_filter*4, self.base_n_filter*8, kernel_size=3, stride=2, padding=1, bias=False)
self.conv3d_c4 = DepthwiseSeparableConv3d(self.base_n_filter*4, self.base_n_filter*8, kernel_size=2, padding=1, stride=1, dilation=1, groups=1, bias=False, kernels_per_layer=1)
self.SE_c4 = ChannelSpatialSELayer3D(self.base_n_filter*4, norm='None')
self.norm_lrelu_conv_c4 = self.norm_lrelu_conv(self.base_n_filter*8, self.base_n_filter*8)
self.inorm3d_c4 = nn.InstanceNorm3d(self.base_n_filter*8)
# Level 5 context pathway, level 0 localization pathway
# self.conv3d_c5 = nn.Conv3d(self.base_n_filter*8, self.base_n_filter*16, kernel_size=3, stride=2, padding=1, bias=False)
self.conv3d_c5 = DepthwiseSeparableConv3d(self.base_n_filter*8, self.base_n_filter*16, kernel_size=3, padding=0, stride=2, dilation=1, groups=1, bias=False, kernels_per_layer=1)
self.SE_c5 = ChannelSpatialSELayer3D(self.base_n_filter*8, norm='None')
self.norm_lrelu_conv_c5 = self.norm_lrelu_conv(self.base_n_filter*16, self.base_n_filter*16)
self.norm_lrelu_upscale_conv_norm_lrelu_l0 = self.norm_lrelu_upscale_conv_norm_lrelu(self.base_n_filter*16, self.base_n_filter*8)
# self.conv3d_l0 = nn.Conv3d(self.base_n_filter*8, self.base_n_filter*8, kernel_size = 1, stride=1, padding=0, bias=False)
self.conv3d_l0 = DepthwiseSeparableConv3d(self.base_n_filter*8, self.base_n_filter*8, kernel_size=1, padding=1, stride=1, dilation=1, groups=1, bias=False, kernels_per_layer=1)
self.SE_l0 = ChannelSpatialSELayer3D(self.base_n_filter*8, norm='None')
self.inorm3d_l0 = nn.InstanceNorm3d(self.base_n_filter*8)
# Level 1 localization pathway
self.conv_norm_lrelu_l1 = self.conv_norm_lrelu(self.base_n_filter*16, self.base_n_filter*16)
# self.conv3d_l1 = nn.Conv3d(self.base_n_filter*16, self.base_n_filter*8, kernel_size=1, stride=1, padding=0, bias=False)
self.conv3d_l1 = DepthwiseSeparableConv3d(self.base_n_filter*16, self.base_n_filter*8, kernel_size=1, padding=1, stride=1, dilation=1, groups=1, bias=False, kernels_per_layer=1)
self.SE_l1 = ChannelSpatialSELayer3D(self.base_n_filter*16, norm='None')
self.norm_lrelu_upscale_conv_norm_lrelu_l1 = self.norm_lrelu_upscale_conv_norm_lrelu(self.base_n_filter*8, self.base_n_filter*4)
# Level 2 localization pathway
self.conv_norm_lrelu_l2 = self.conv_norm_lrelu(self.base_n_filter*8, self.base_n_filter*8)
# self.conv3d_l2 = nn.Conv3d(self.base_n_filter*8, self.base_n_filter*4, kernel_size=1, stride=1, padding=0, bias=False)
self.conv3d_l2 = DepthwiseSeparableConv3d(self.base_n_filter*8, self.base_n_filter*4, kernel_size=1, padding=1, stride=1, dilation=1, groups=1, bias=False, kernels_per_layer=1)
self.SE_l2 = ChannelSpatialSELayer3D(self.base_n_filter*8, norm='None')
self.norm_lrelu_upscale_conv_norm_lrelu_l2 = self.norm_lrelu_upscale_conv_norm_lrelu(self.base_n_filter*4, self.base_n_filter*2)
# Level 3 localization pathway
self.conv_norm_lrelu_l3 = self.conv_norm_lrelu(self.base_n_filter*4, self.base_n_filter*4)
# self.conv3d_l3 = nn.Conv3d(self.base_n_filter*4, self.base_n_filter*2, kernel_size=1, stride=1, padding=0, bias=False)
self.conv3d_l3 = DepthwiseSeparableConv3d(self.base_n_filter*4, self.base_n_filter*2, kernel_size=1, padding=1, stride=1, dilation=1, groups=1, bias=False, kernels_per_layer=1)
self.SE_l3 = ChannelSpatialSELayer3D(self.base_n_filter*4, norm='None')
self.norm_lrelu_upscale_conv_norm_lrelu_l3 = self.norm_lrelu_upscale_conv_norm_lrelu(self.base_n_filter*2, self.base_n_filter)
# Level 4 localization pathway
self.conv_norm_lrelu_l4 = self.conv_norm_lrelu(self.base_n_filter*2, self.base_n_filter*2)
# self.conv3d_l4 = nn.Conv3d(self.base_n_filter*2, self.n_classes, kernel_size=1, stride=1, padding=0, bias=False)
self.conv3d_l4 = DepthwiseSeparableConv3d(self.base_n_filter*2, self.n_classes, kernel_size=1, padding=1, stride=1, dilation=1, groups=1, bias=False, kernels_per_layer=1)
self.SE_l4 = ChannelSpatialSELayer3D(self.base_n_filter*2, norm='None')
self.ds2_1x1_conv3d = nn.Conv3d(self.base_n_filter*8, self.n_classes, kernel_size=1, stride=1, padding=0, bias=False)
# self.ds2_1x1_conv3d = DepthwiseSeparableConv3d(self.base_n_filter*8, self.n_classes, kernel_size=1, padding=1, stride=1, dilation=1, groups=1, bias=False, kernels_per_layer=1)
self.ds3_1x1_conv3d = nn.Conv3d(self.base_n_filter*4, self.n_classes, kernel_size=1, stride=1, padding=0, bias=False)
# self.ds3_1x1_conv3d = DepthwiseSeparableConv3d(self.base_n_filter*4, self.n_classes, kernel_size=1, padding=1, stride=1, dilation=1, groups=1, bias=False, kernels_per_layer=1)
def conv_norm_lrelu(self, feat_in, feat_out):
return nn.Sequential(
nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False),
nn.InstanceNorm3d(feat_out),
nn.LeakyReLU())
def norm_lrelu_conv(self, feat_in, feat_out):
return nn.Sequential(
nn.InstanceNorm3d(feat_in),
nn.LeakyReLU(),
nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False))
def lrelu_conv(self, feat_in, feat_out):
return nn.Sequential(
nn.LeakyReLU(),
nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False))
def norm_lrelu_upscale_conv_norm_lrelu(self, feat_in, feat_out):
return nn.Sequential(
nn.InstanceNorm3d(feat_in),
nn.LeakyReLU(),
nn.Upsample(scale_factor=2, mode='nearest'),
# should be feat_in*2 or feat_in
nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False),
nn.InstanceNorm3d(feat_out),
nn.LeakyReLU())
def forward(self, x):
# Level 1 context pathway
out = self.conv3d_c1_1(x)
residual_1 = out
out = self.lrelu(out)
out = self.conv3d_c1_2(out)
out = self.SE_c1(out)
out = self.dropout3d(out)
out = self.lrelu_conv_c1(out)
# Element Wise Summation
out += residual_1
context_1 = self.lrelu(out)
out = self.inorm3d_c1(out)
out = self.lrelu(out)
# Level 2 context pathway
out = self.conv3d_c2(out)
out = self.SE_c2(out)
residual_2 = out
out = self.norm_lrelu_conv_c2(out)
out = self.dropout3d(out)
out = self.norm_lrelu_conv_c2(out)
out += residual_2
out = self.inorm3d_c2(out)
out = self.lrelu(out)
context_2 = out
# Level 3 context pathway
out = self.conv3d_c3(out)
out = self.SE_c3(out)
residual_3 = out
out = self.norm_lrelu_conv_c3(out)
out = self.dropout3d(out)
out = self.norm_lrelu_conv_c3(out)
out += residual_3
out = self.inorm3d_c3(out)
out = self.lrelu(out)
context_3 = out
# Level 4 context pathway
out = self.conv3d_c4(out)
out = self.SE_c4(out)
residual_4 = out
out = self.norm_lrelu_conv_c4(out)
out = self.dropout3d(out)
out = self.norm_lrelu_conv_c4(out)
out += residual_4
out = self.inorm3d_c4(out)
out = self.lrelu(out)
context_4 = out
# Level 5 context pathway
out = self.conv3d_c5(out)
out = self.SE_c5(out)
residual_5 = out
out = self.norm_lrelu_conv_c5(out)
out = self.dropout3d(out)
out = self.norm_lrelu_conv_c5(out)
out += residual_5
out = self.norm_lrelu_upscale_conv_norm_lrelu_l0(out)
out = self.conv3d_l0(out)
out = self.inorm3d_l0(out)
out = self.lrelu(out)
if out.size()[2:] != context_4.size()[2:]:
out = F.interpolate(out, size=context_4.size()[2:], mode='trilinear', align_corners=True)
# Level 1 localization pathway
out = torch.cat([out, context_4], dim=1)
# out = self.attn1(out, context_4)
# out = torch.cat([out, context_4], dim=1)
out = self.conv_norm_lrelu_l1(out)
out = self.conv3d_l1(out)
out = self.SE_l1(out)
out = self.norm_lrelu_upscale_conv_norm_lrelu_l1(out)
if out.size()[2:] != context_3.size()[2:]:
out = F.interpolate(out, size=context_3.size()[2:], mode='trilinear', align_corners=True)
# Level 2 localization pathway
out = torch.cat([out, context_3], dim=1)
# out = self.attn1(out, context_3)
out = self.conv_norm_lrelu_l2(out)
ds2 = out
out = self.conv3d_l2(out)
out = self.SE_l2(out)
out = self.norm_lrelu_upscale_conv_norm_lrelu_l2(out)
if out.size()[2:] != context_2.size()[2:]:
out = F.interpolate(out, size=context_2.size()[2:], mode='trilinear', align_corners=True)
# Level 3 localization pathway
out = torch.cat([out, context_2], dim=1)
# out = self.attn1(out, context_2)
out = self.conv_norm_lrelu_l3(out)
ds3 = out
out = self.conv3d_l3(out)
out = self.SE_l3(out)
out = self.norm_lrelu_upscale_conv_norm_lrelu_l3(out)
if out.size()[2:] != context_1.size()[2:]:
out = F.interpolate(out, size=context_1.size()[2:], mode='trilinear', align_corners=True)
# Level 4 localization pathway
out = torch.cat([out, context_1], dim=1)
# out = self.attn1(out, context_1)
out = self.conv_norm_lrelu_l4(out)
out = self.SE_l4(out)
out_pred = self.conv3d_l4(out)
ds2_1x1_conv = self.ds2_1x1_conv3d(ds2)
ds1_ds2_sum_upscale = self.upsacle(ds2_1x1_conv)
ds3_1x1_conv = self.ds3_1x1_conv3d(ds3)
if ds1_ds2_sum_upscale.size()[2:] != ds3_1x1_conv.size()[2:]:
# Resize ds1_ds2_sum_upscale to match the spatial dimensions of ds3_1x1_conv
ds1_ds2_sum_upscale = F.interpolate(ds1_ds2_sum_upscale, size=ds3_1x1_conv.size()[2:], mode='trilinear', align_corners=True)
ds1_ds2_sum_upscale_ds3_sum = ds1_ds2_sum_upscale + ds3_1x1_conv
ds1_ds2_sum_upscale_ds3_sum_upscale = self.upsacle(ds1_ds2_sum_upscale_ds3_sum)
if out_pred.size()[2:] != ds1_ds2_sum_upscale_ds3_sum_upscale.size()[2:]:
# Resize ds1_ds2_sum_upscale to match the spatial dimensions of ds3_1x1_conv
out_pred = F.interpolate(out_pred, size=ds1_ds2_sum_upscale_ds3_sum_upscale.size()[2:], mode='trilinear', align_corners=True)
out = out_pred + ds1_ds2_sum_upscale_ds3_sum_upscale
return out
#seg_layer = out
#out = out.permute(0, 2, 3, 4, 1).contiguous().view(-1, self.n_classes)
#out = self.softmax(out)
#return out, seg_layer