RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x16 and 8x4)

I want to train my UNet model with a 128x128x128 input, but I get the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<timed eval> in <module>

/tmp/ipykernel_17/3349595807.py in run(self)
     91     def run(self):
     92         for epoch in range(self.num_epochs):
---> 93             self._do_epoch(epoch, "train")
     94             with torch.no_grad():
     95                 val_loss = self._do_epoch(epoch, "val")

/tmp/ipykernel_17/3349595807.py in _do_epoch(self, epoch, phase)
     66         for itr, data_batch in enumerate(dataloader):
     67             images, targets = data_batch['image'], data_batch['mask']
---> 68             loss, logits = self._compute_loss_and_outputs(images, targets)
     69             loss = loss / self.accumulation_steps
     70             if phase == "train":

/tmp/ipykernel_17/3349595807.py in _compute_loss_and_outputs(self, images, targets)
     51         images = images.to(self.device)
     52         targets = targets.to(self.device)
---> 53         logits = self.net(images)
     54         loss = self.criterion(logits, targets)
     55         return loss, logits

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

/tmp/ipykernel_17/948357362.py in forward(self, x)
    138         # Level 2 context pathway
    139         out = self.conv3d_c2(out)
--> 140         out = self.SE_c2(out)
    141         residual_2 = out
    142         out = self.norm_lrelu_conv_c2(out)

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

/tmp/ipykernel_17/673886512.py in forward(self, input_tensor)
    116         :return: output_tensor
    117         """
--> 118         output_tensor = torch.max(self.cSE(input_tensor), self.sSE(input_tensor))
    119         return output_tensor

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

/tmp/ipykernel_17/673886512.py in forward(self, input_tensor)
     38 
     39         # channel excitation
---> 40         fc_out_1 = self.relu(self.fc1(squeeze_tensor.view(batch_size, num_channels)))
     41         print(f"fc_out_1 shape: {fc_out_1.shape}")
     42         fc_out_2 = self.sigmoid(self.fc2(fc_out_1))

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/linear.py in forward(self, input)
    101 
    102     def forward(self, input: Tensor) -> Tensor:
--> 103         return F.linear(input, self.weight, self.bias)
    104 
    105     def extra_repr(self) -> str:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (3x16 and 8x4)
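
Reading the shapes in the last line: mat1 (3x16) is the pooled squeeze tensor entering fc1 (batch size 3, 16 channels), and mat2 (8x4) is the transposed weight of fc1, i.e. an nn.Linear(8, 4), so fc1 appears to have been built for 8 input channels but is receiving 16.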

My model looks like this:

import torch
from torch import nn as nn
from torch.nn import functional as F


class ChannelSELayer3D(nn.Module):
    """
    3D extension of Squeeze-and-Excitation (SE) block described in:
        *Hu et al., Squeeze-and-Excitation Networks, arXiv:1709.01507*
        *Zhu et al., AnatomyNet, arXiv:1808.05238*
    """

    def __init__(self, num_channels, reduction_ratio=2, norm='None'):
        """
        :param num_channels: No of input channels
        :param reduction_ratio: By how much the number of channels should be reduced
        """
        super(ChannelSELayer3D, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        num_channels_reduced = num_channels // reduction_ratio
        self.reduction_ratio = reduction_ratio
        self.fc1 = nn.Linear(num_channels, num_channels_reduced, bias=True)
        self.fc2 = nn.Linear(num_channels_reduced, num_channels, bias=True)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.norm = norm
        self.bn = nn.BatchNorm3d(num_channels)

    def forward(self, input_tensor):
        """
        :param input_tensor: X, shape = (batch_size, num_channels, D, H, W)
        :return: output tensor
        """
        batch_size, num_channels, D, H, W = input_tensor.size()
        # Average along each channel
        squeeze_tensor = self.avg_pool(input_tensor)

        # channel excitation
        fc_out_1 = self.relu(self.fc1(squeeze_tensor.view(batch_size, num_channels)))
        print(f"fc_out_1 shape: {fc_out_1.shape}")
        fc_out_2 = self.sigmoid(self.fc2(fc_out_1))
        print(f"fc_out_2 shape: {fc_out_2.shape}")

        output_tensor = torch.mul(input_tensor, fc_out_2.view(batch_size, num_channels, 1, 1, 1))

        if self.norm == 'BN':
            output_tensor = self.bn(output_tensor)

        return output_tensor
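
For reference, fc1 hard-wires the channel count, so the layer only accepts inputs whose channel dimension matches num_channels at construction time. A minimal sketch of the failure mode (toy shapes for illustration, not the real training data):

se_ok = ChannelSELayer3D(num_channels=16)   # fc1 = Linear(16, 8)
x = torch.randn(1, 16, 8, 8, 8)             # (batch, channels, D, H, W)
print(se_ok(x).shape)                       # torch.Size([1, 16, 8, 8, 8])

se_bad = ChannelSELayer3D(num_channels=8)   # fc1 = Linear(8, 4)
# se_bad(x)  # RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x16 and 8x4)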


class SpatialSELayer3D(nn.Module):
    """
    3D extension of SE block -- squeezing spatially and exciting channel-wise described in:
        *Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, MICCAI 2018*
    """

    def __init__(self, num_channels, norm='None'):
        """
        :param num_channels: No of input channels
        """
        super(SpatialSELayer3D, self).__init__()
        self.conv = nn.Conv3d(num_channels, 1, 1)
        self.sigmoid = nn.Sigmoid()
        self.norm = norm
        self.bn = nn.BatchNorm3d(1)

    def forward(self, input_tensor, weights=None):
        """
        :param weights: weights for few shot learning
        :param input_tensor: X, shape = (batch_size, num_channels, D, H, W)
        :return: output_tensor
        """
        # channel squeeze
        batch_size, channel, D, H, W = input_tensor.size()

        if weights is not None:
            # few-shot path: a 5D input needs a 5D weight and a 3D convolution
            weights = weights.view(1, channel, 1, 1, 1)
            out = F.conv3d(input_tensor, weights)
        else:
            out = self.conv(input_tensor)

            if self.norm == 'BN':
                out = self.bn(out)

        squeeze_tensor = self.sigmoid(out)

        # spatial excitation
        output_tensor = torch.mul(input_tensor, squeeze_tensor.view(batch_size, 1, D, H, W))

        return output_tensor


class ChannelSpatialSELayer3D(nn.Module):
    """
       3D extension of concurrent spatial and channel squeeze & excitation:
           *Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, arXiv:1803.02579*
       """

    def __init__(self, num_channels, reduction_ratio=2, norm='None'):
        """
        :param num_channels: No of input channels
        :param reduction_ratio: By how much the number of channels should be reduced
        """
        super(ChannelSpatialSELayer3D, self).__init__()
        self.cSE = ChannelSELayer3D(num_channels, reduction_ratio, norm=norm)
        self.sSE = SpatialSELayer3D(num_channels, norm=norm)
        self.norm = norm

    def forward(self, input_tensor):
        """
        :param input_tensor: X, shape = (batch_size, num_channels, D, H, W)
        :return: output_tensor
        """
        output_tensor = torch.max(self.cSE(input_tensor), self.sSE(input_tensor))
        return output_tensor
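
A quick shape check of the combined block (toy 16-channel input, assumed for illustration):

csse = ChannelSpatialSELayer3D(num_channels=16)
x = torch.randn(2, 16, 4, 4, 4)
print(csse(x).shape)  # torch.Size([2, 16, 4, 4, 4]) -- element-wise max of the cSE and sSE outputs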

class Modified3DUNet(nn.Module):
    def __init__(self, in_channels=4, n_classes=3, base_n_filter = 8):
        super(Modified3DUNet, self).__init__()
        self.in_channels = in_channels
        self.n_classes = n_classes
        self.base_n_filter = base_n_filter
        output_channels = base_n_filter
        filters = [4, 8, 16, 32, 64]

        self.lrelu = nn.LeakyReLU()
        self.dropout3d = nn.Dropout3d(p=0.6)
        self.upscale = nn.Upsample(scale_factor=4, mode='nearest')
        self.softmax = nn.Softmax(dim=1)

        # Level 1 context pathway
        #       self.conv3d_c1_1 = nn.Conv3d(self.in_channels, self.base_n_filter, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv3d_c1_1 = DepthwiseSeparableConv3d(self.in_channels, output_channels, kernel_size=3, padding=1, stride=1, dilation=1, groups=1, bias=False, kernels_per_layer=1)
        #       self.conv3d_c1_2 = nn.Conv3d(self.base_n_filter, self.base_n_filter, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv3d_c1_2 = DepthwiseSeparableConv3d(self.base_n_filter, self.base_n_filter, kernel_size=3, padding=1, stride=1, dilation=1, groups=1, bias=False, kernels_per_layer=1)
        self.SE_c1 = ChannelSpatialSELayer3D(self.base_n_filter*1, norm='None')
        self.lrelu_conv_c1 = self.lrelu_conv(self.base_n_filter, self.base_n_filter)
        self.inorm3d_c1 = nn.InstanceNorm3d(self.base_n_filter)

        # Level 2 context pathway
        #       self.conv3d_c2 = nn.Conv3d(self.base_n_filter, self.base_n_filter*2, kernel_size=3, stride=2, padding=1, bias=False)
        self.conv3d_c2 = DepthwiseSeparableConv3d(self.base_n_filter, self.base_n_filter*2, kernel_size=3, padding=1, stride=2, dilation=1, groups=1, bias=False, kernels_per_layer=1)
        self.SE_c2 = ChannelSpatialSELayer3D(self.base_n_filter, norm='None')
        self.norm_lrelu_conv_c2 = self.norm_lrelu_conv(self.base_n_filter*2, self.base_n_filter*2)
        self.inorm3d_c2 = nn.InstanceNorm3d(self.base_n_filter*2)

        # Level 3 context pathway
        #       self.conv3d_c3 = nn.Conv3d(self.base_n_filter*2, self.base_n_filter*4, kernel_size=3, stride=2, padding=1, bias=False)
        self.conv3d_c3 = DepthwiseSeparableConv3d(self.base_n_filter*2, self.base_n_filter*4, kernel_size=3, padding=1, stride=1, dilation=1, groups=1, bias=False, kernels_per_layer=1)
        self.SE_c3 = ChannelSpatialSELayer3D(self.base_n_filter*2, norm='None')
        self.norm_lrelu_conv_c3 = self.norm_lrelu_conv(self.base_n_filter*4, self.base_n_filter*4)
        self.inorm3d_c3 = nn.InstanceNorm3d(self.base_n_filter*4)

        # Level 4 context pathway
        #       self.conv3d_c4 = nn.Conv3d(self.base_n_filter*4, self.base_n_filter*8, kernel_size=3, stride=2, padding=1, bias=False)
        self.conv3d_c4 = DepthwiseSeparableConv3d(self.base_n_filter*4, self.base_n_filter*8, kernel_size=2, padding=1, stride=1, dilation=1, groups=1, bias=False, kernels_per_layer=1)
        self.SE_c4 = ChannelSpatialSELayer3D(self.base_n_filter*4, norm='None')
        self.norm_lrelu_conv_c4 = self.norm_lrelu_conv(self.base_n_filter*8, self.base_n_filter*8)
        self.inorm3d_c4 = nn.InstanceNorm3d(self.base_n_filter*8)

        # Level 5 context pathway, level 0 localization pathway
        #       self.conv3d_c5 = nn.Conv3d(self.base_n_filter*8, self.base_n_filter*16, kernel_size=3, stride=2, padding=1, bias=False)
        self.conv3d_c5 = DepthwiseSeparableConv3d(self.base_n_filter*8, self.base_n_filter*16, kernel_size=3, padding=0, stride=2, dilation=1, groups=1, bias=False, kernels_per_layer=1)
        self.SE_c5 = ChannelSpatialSELayer3D(self.base_n_filter*8, norm='None')
        self.norm_lrelu_conv_c5 = self.norm_lrelu_conv(self.base_n_filter*16, self.base_n_filter*16)
        self.norm_lrelu_upscale_conv_norm_lrelu_l0 = self.norm_lrelu_upscale_conv_norm_lrelu(self.base_n_filter*16, self.base_n_filter*8)

        #       self.conv3d_l0 = nn.Conv3d(self.base_n_filter*8, self.base_n_filter*8, kernel_size = 1, stride=1, padding=0, bias=False)
        self.conv3d_l0 = DepthwiseSeparableConv3d(self.base_n_filter*8, self.base_n_filter*8, kernel_size=1, padding=1, stride=1, dilation=1, groups=1, bias=False, kernels_per_layer=1)
        self.SE_l0 = ChannelSpatialSELayer3D(self.base_n_filter*8, norm='None')
        self.inorm3d_l0 = nn.InstanceNorm3d(self.base_n_filter*8)
        
        # Level 1 localization pathway
        self.conv_norm_lrelu_l1 = self.conv_norm_lrelu(self.base_n_filter*16, self.base_n_filter*16)
        #       self.conv3d_l1 = nn.Conv3d(self.base_n_filter*16, self.base_n_filter*8, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv3d_l1 = DepthwiseSeparableConv3d(self.base_n_filter*16, self.base_n_filter*8, kernel_size=1, padding=1, stride=1, dilation=1, groups=1, bias=False, kernels_per_layer=1)
        self.SE_l1 = ChannelSpatialSELayer3D(self.base_n_filter*16, norm='None')
        self.norm_lrelu_upscale_conv_norm_lrelu_l1 = self.norm_lrelu_upscale_conv_norm_lrelu(self.base_n_filter*8, self.base_n_filter*4)

        # Level 2 localization pathway
        self.conv_norm_lrelu_l2 = self.conv_norm_lrelu(self.base_n_filter*8, self.base_n_filter*8)
        #       self.conv3d_l2 = nn.Conv3d(self.base_n_filter*8, self.base_n_filter*4, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv3d_l2 = DepthwiseSeparableConv3d(self.base_n_filter*8, self.base_n_filter*4, kernel_size=1, padding=1, stride=1, dilation=1, groups=1, bias=False, kernels_per_layer=1)
        self.SE_l2 = ChannelSpatialSELayer3D(self.base_n_filter*8, norm='None')
        self.norm_lrelu_upscale_conv_norm_lrelu_l2 = self.norm_lrelu_upscale_conv_norm_lrelu(self.base_n_filter*4, self.base_n_filter*2)

        # Level 3 localization pathway
        self.conv_norm_lrelu_l3 = self.conv_norm_lrelu(self.base_n_filter*4, self.base_n_filter*4)
        #       self.conv3d_l3 = nn.Conv3d(self.base_n_filter*4, self.base_n_filter*2, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv3d_l3 = DepthwiseSeparableConv3d(self.base_n_filter*4, self.base_n_filter*2, kernel_size=1, padding=1, stride=1, dilation=1, groups=1, bias=False, kernels_per_layer=1)
        self.SE_l3 = ChannelSpatialSELayer3D(self.base_n_filter*4, norm='None')
        self.norm_lrelu_upscale_conv_norm_lrelu_l3 = self.norm_lrelu_upscale_conv_norm_lrelu(self.base_n_filter*2, self.base_n_filter)

        # Level 4 localization pathway
        self.conv_norm_lrelu_l4 = self.conv_norm_lrelu(self.base_n_filter*2, self.base_n_filter*2)
        #       self.conv3d_l4 = nn.Conv3d(self.base_n_filter*2, self.n_classes, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv3d_l4 = DepthwiseSeparableConv3d(self.base_n_filter*2, self.n_classes, kernel_size=1, padding=1, stride=1, dilation=1, groups=1, bias=False, kernels_per_layer=1)
        self.SE_l4 = ChannelSpatialSELayer3D(self.base_n_filter*2, norm='None')
        
        self.ds2_1x1_conv3d = nn.Conv3d(self.base_n_filter*8, self.n_classes, kernel_size=1, stride=1, padding=0, bias=False)
#         self.ds2_1x1_conv3d =  DepthwiseSeparableConv3d(self.base_n_filter*8, self.n_classes, kernel_size=1, padding=1, stride=1, dilation=1, groups=1, bias=False, kernels_per_layer=1)
        self.ds3_1x1_conv3d = nn.Conv3d(self.base_n_filter*4, self.n_classes, kernel_size=1, stride=1, padding=0, bias=False)
#         self.ds3_1x1_conv3d = DepthwiseSeparableConv3d(self.base_n_filter*4, self.n_classes, kernel_size=1, padding=1, stride=1, dilation=1, groups=1, bias=False, kernels_per_layer=1)

        
    def conv_norm_lrelu(self, feat_in, feat_out):
        return nn.Sequential(
            nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm3d(feat_out),
            nn.LeakyReLU())

    def norm_lrelu_conv(self, feat_in, feat_out):
        return nn.Sequential(
            nn.InstanceNorm3d(feat_in),
            nn.LeakyReLU(),
            nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False))

    def lrelu_conv(self, feat_in, feat_out):
        return nn.Sequential(
            nn.LeakyReLU(),
            nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False))

    def norm_lrelu_upscale_conv_norm_lrelu(self, feat_in, feat_out):
        return nn.Sequential(
            nn.InstanceNorm3d(feat_in),
            nn.LeakyReLU(),
            nn.Upsample(scale_factor=2, mode='nearest'),
            # should be feat_in*2 or feat_in
            nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm3d(feat_out),
            nn.LeakyReLU())

    def forward(self, x):
        #  Level 1 context pathway
        out = self.conv3d_c1_1(x)
        residual_1 = out
        out = self.lrelu(out)
        out = self.conv3d_c1_2(out)
        out = self.SE_c1(out)
        out = self.dropout3d(out)
        out = self.lrelu_conv_c1(out)
        # Element Wise Summation
        out += residual_1
        context_1 = self.lrelu(out)
        out = self.inorm3d_c1(out)
        out = self.lrelu(out)

        # Level 2 context pathway
        out = self.conv3d_c2(out)
        out = self.SE_c2(out)
        residual_2 = out
        out = self.norm_lrelu_conv_c2(out)
        out = self.dropout3d(out)
        out = self.norm_lrelu_conv_c2(out)
        out += residual_2
        out = self.inorm3d_c2(out)
        out = self.lrelu(out)
        context_2 = out

        # Level 3 context pathway
        out = self.conv3d_c3(out)
        out = self.SE_c3(out)
        residual_3 = out
        out = self.norm_lrelu_conv_c3(out)
        out = self.dropout3d(out)
        out = self.norm_lrelu_conv_c3(out)
        out += residual_3
        out = self.inorm3d_c3(out)
        out = self.lrelu(out)
        context_3 = out

        # Level 4 context pathway
        out = self.conv3d_c4(out)
        out = self.SE_c4(out)
        residual_4 = out
        out = self.norm_lrelu_conv_c4(out)
        out = self.dropout3d(out)
        out = self.norm_lrelu_conv_c4(out)
        out += residual_4
        out = self.inorm3d_c4(out)
        out = self.lrelu(out)
        context_4 = out

        # Level 5 context pathway
        out = self.conv3d_c5(out)
        out = self.SE_c5(out)
        residual_5 = out
        out = self.norm_lrelu_conv_c5(out)
        out = self.dropout3d(out)
        out = self.norm_lrelu_conv_c5(out)
        out += residual_5
        out = self.norm_lrelu_upscale_conv_norm_lrelu_l0(out)

        out = self.conv3d_l0(out)
        out = self.inorm3d_l0(out)
        out = self.lrelu(out)

        if out.size()[2:] != context_4.size()[2:]:
            out = F.interpolate(out, size=context_4.size()[2:], mode='trilinear', align_corners=True)

        # Level 1 localization pathway
        out = torch.cat([out, context_4], dim=1)
#         out = self.attn1(out, context_4)
        out = self.conv_norm_lrelu_l1(out)
        out = self.conv3d_l1(out)
        out = self.SE_l1(out)
        out = self.norm_lrelu_upscale_conv_norm_lrelu_l1(out)

        if out.size()[2:] != context_3.size()[2:]:
            out = F.interpolate(out, size=context_3.size()[2:], mode='trilinear', align_corners=True)


        # Level 2 localization pathway
        out = torch.cat([out, context_3], dim=1)
#         out = self.attn1(out, context_3)
        out = self.conv_norm_lrelu_l2(out)
        ds2 = out
        out = self.conv3d_l2(out)
        out = self.SE_l2(out)
        out = self.norm_lrelu_upscale_conv_norm_lrelu_l2(out)

        if out.size()[2:] != context_2.size()[2:]:
            out = F.interpolate(out, size=context_2.size()[2:], mode='trilinear', align_corners=True)

        # Level 3 localization pathway
        out = torch.cat([out, context_2], dim=1)
#         out = self.attn1(out, context_2)
        out = self.conv_norm_lrelu_l3(out)
        ds3 = out
        out = self.conv3d_l3(out)
        out = self.SE_l3(out)
        out = self.norm_lrelu_upscale_conv_norm_lrelu_l3(out)

        if out.size()[2:] != context_1.size()[2:]:
            out = F.interpolate(out, size=context_1.size()[2:], mode='trilinear', align_corners=True)

        # Level 4 localization pathway
        out = torch.cat([out, context_1], dim=1)
#         out = self.attn1(out, context_1)
        out = self.conv_norm_lrelu_l4(out)
        out = self.SE_l4(out)
        out_pred = self.conv3d_l4(out)

        ds2_1x1_conv = self.ds2_1x1_conv3d(ds2)
        ds1_ds2_sum_upscale = self.upscale(ds2_1x1_conv)
        ds3_1x1_conv = self.ds3_1x1_conv3d(ds3)
        if ds1_ds2_sum_upscale.size()[2:] != ds3_1x1_conv.size()[2:]:
            # Resize ds1_ds2_sum_upscale to match the spatial dimensions of ds3_1x1_conv
            ds1_ds2_sum_upscale = F.interpolate(ds1_ds2_sum_upscale, size=ds3_1x1_conv.size()[2:], mode='trilinear', align_corners=True)
        ds1_ds2_sum_upscale_ds3_sum = ds1_ds2_sum_upscale + ds3_1x1_conv
        ds1_ds2_sum_upscale_ds3_sum_upscale = self.upscale(ds1_ds2_sum_upscale_ds3_sum)

        if out_pred.size()[2:] != ds1_ds2_sum_upscale_ds3_sum_upscale.size()[2:]:
            # Resize out_pred to match the spatial dimensions of ds1_ds2_sum_upscale_ds3_sum_upscale
            out_pred = F.interpolate(out_pred, size=ds1_ds2_sum_upscale_ds3_sum_upscale.size()[2:], mode='trilinear', align_corners=True)

        out = out_pred + ds1_ds2_sum_upscale_ds3_sum_upscale
        return out
        #seg_layer = out
        #out = out.permute(0, 2, 3, 4, 1).contiguous().view(-1, self.n_classes)
        #out = self.softmax(out)

        #return out, seg_layer

Could you add the missing code pieces to make your code snippet executable so that we could reproduce the issue?

Oh yeah, I forgot the depthwise block:

class DepthwiseSeparableConv3d(nn.Module):
    def __init__(self, in_channels, output_channels, kernel_size, padding, stride, dilation=1, groups=1, bias=False, kernels_per_layer=1):
        super(DepthwiseSeparableConv3d, self).__init__()
        self.depthwise = nn.Conv3d(in_channels, in_channels * kernels_per_layer, kernel_size, padding=padding,
                                   stride=stride, dilation=dilation, groups=in_channels, bias=bias)
        # pointwise 1x1 conv: stride and dilation stay at 1 so only the depthwise conv downsamples
        self.pointwise = nn.Conv3d(in_channels * kernels_per_layer, output_channels, 1, padding=0,
                                   stride=1, dilation=1, groups=groups, bias=bias)

    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        return x
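
With the pointwise stride left at 1, a quick check that the separable conv reproduces the output shape of the dense conv it replaces (stride-2 case from conv3d_c2; sizes are toy values for illustration):

dsc = DepthwiseSeparableConv3d(8, 16, kernel_size=3, padding=1, stride=2)
x = torch.randn(1, 8, 32, 32, 32)
print(dsc(x).shape)  # torch.Size([1, 16, 16, 16, 16]), same as nn.Conv3d(8, 16, 3, stride=2, padding=1)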

Thank you in advance @ptrblck. Actually, I already solved it; I just trial-and-errored the input shape haha… But can I ask you, how can I improve the model? I'm stuck on performance and haven't reached my Dice goal. Do you have any reference code?
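
For anyone landing here with the same error: trial-and-erroring the input shape only masks the issue. The mismatch appears to come from several SE layers being constructed with the channel count going *into* the preceding conv rather than coming *out* of it. For example, conv3d_c2 outputs base_n_filter*2 = 16 channels while SE_c2 is built for base_n_filter = 8, so its fc1 is a Linear(8, 4) that meets a 16-feature input, which is exactly the traceback above. A sketch of the likely fix, matching each SE layer to the output channels of the layer it follows:

# context pathway
self.SE_c2 = ChannelSpatialSELayer3D(self.base_n_filter*2, norm='None')   # after conv3d_c2
self.SE_c3 = ChannelSpatialSELayer3D(self.base_n_filter*4, norm='None')   # after conv3d_c3
self.SE_c4 = ChannelSpatialSELayer3D(self.base_n_filter*8, norm='None')   # after conv3d_c4
self.SE_c5 = ChannelSpatialSELayer3D(self.base_n_filter*16, norm='None')  # after conv3d_c5

# localization pathway
self.SE_l1 = ChannelSpatialSELayer3D(self.base_n_filter*8, norm='None')   # after conv3d_l1
self.SE_l2 = ChannelSpatialSELayer3D(self.base_n_filter*4, norm='None')   # after conv3d_l2
self.SE_l3 = ChannelSpatialSELayer3D(self.base_n_filter*2, norm='None')   # after conv3d_l3

(SE_c1 and SE_l4 already match their inputs and can stay as they are.)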