How can I increase the depth of this model, i.e., add more layers to this code? Please help me.

import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvBlock(nn.Module):
    """3x3 convolution followed by batch normalisation and a PReLU.

    Args:
        in_planes: number of channels of the incoming tensor.
        out_planes: number of channels produced by the convolution.
        stride: stride of the 3x3 convolution (default 1).
    """

    def __init__(self, in_planes, out_planes, stride=1):
        super().__init__()
        # padding=1 with a 3x3 kernel; the convolution has no bias term.
        self.conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_planes)
        # PReLU is created with one slope parameter per output channel.
        self.act = nn.PReLU(out_planes)

    def forward(self, input):
        """Run ``input`` through conv -> batch norm -> PReLU and return it."""
        features = self.conv(input)
        features = self.bn(features)
        return self.act(features)


class DilatedParallelConvBlockD2(nn.Module):
    """1x1 projection feeding two parallel depthwise 3x3 convs, summed and normalised.

    The two depthwise branches use dilation 1 and dilation 2 (with matching
    padding, so both keep the spatial size). No activation is applied here;
    the block ends with batch normalisation.

    Args:
        in_planes: number of channels of the incoming tensor.
        out_planes: number of channels after the 1x1 projection and of the output.
    """

    def __init__(self, in_planes, out_planes):
        super().__init__()
        # Settings shared by both depthwise branches (groups == channels).
        depthwise = dict(kernel_size=3, stride=1, groups=out_planes, bias=False)
        self.conv0 = nn.Conv2d(
            in_planes, out_planes, kernel_size=1, stride=1, padding=0,
            dilation=1, groups=1, bias=False,
        )
        self.conv1 = nn.Conv2d(out_planes, out_planes, padding=1, dilation=1, **depthwise)
        self.conv2 = nn.Conv2d(out_planes, out_planes, padding=2, dilation=2, **depthwise)
        self.bn = nn.BatchNorm2d(out_planes)

    def forward(self, input):
        """Project ``input``, add the two dilated branches and batch-normalise the sum."""
        projected = self.conv0(input)
        summed = self.conv1(projected) + self.conv2(projected)
        return self.bn(summed)


class DilatedParallelConvBlock(nn.Module):
    """Four parallel dilated depthwise branches with a per-branch attention gate.

    The input is reduced to ``out_planes // 4`` channels by a 1x1 convolution,
    passed through four depthwise 3x3 convolutions (dilations 1, 2, 4, 8) and
    an average pool, accumulated branch by branch, gated, concatenated back to
    ``out_planes`` channels and fused by a grouped 1x1 convolution.

    Args:
        in_planes: number of channels of the incoming tensor.
        out_planes: number of output channels; must be a multiple of 4.
        stride: stride shared by the four branches and the pooling (default 1).
    """

    def __init__(self, in_planes, out_planes, stride=1):
        super().__init__()
        assert out_planes % 4 == 0
        inter_planes = out_planes // 4

        def depthwise(rate):
            # padding == dilation keeps all branches at the same output size.
            return nn.Conv2d(
                inter_planes, inter_planes, kernel_size=3, stride=stride,
                padding=rate, dilation=rate, groups=inter_planes, bias=False,
            )

        self.conv1x1_down = nn.Conv2d(
            in_planes, inter_planes, kernel_size=1, padding=0, groups=1, bias=False,
        )
        self.conv1 = depthwise(1)
        self.conv2 = depthwise(2)
        self.conv3 = depthwise(4)
        self.conv4 = depthwise(8)
        self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)
        self.conv1x1_fuse = nn.Conv2d(
            out_planes, out_planes, kernel_size=1, padding=0, groups=4, bias=False,
        )
        # groups=4 with 4 outputs: one attention map per branch.
        self.attention = nn.Conv2d(
            out_planes, 4, kernel_size=1, padding=0, groups=4, bias=False,
        )
        self.bn = nn.BatchNorm2d(out_planes)
        self.act = nn.PReLU(out_planes)

    def forward(self, input):
        """Return the fused, normalised and activated multi-dilation features."""
        reduced = self.conv1x1_down(input)
        responses = [
            branch(reduced)
            for branch in (self.conv1, self.conv2, self.conv3, self.conv4)
        ]

        # Accumulate: pooled features seed branch 1, and every later branch
        # is added on top of the running sum of the previous ones.
        running = self.pool(reduced)
        stacked = []
        for response in responses:
            running = running + response
            stacked.append(running)

        gates = torch.sigmoid(self.attention(torch.cat(stacked, 1)))
        # Each branch is re-weighted by its own single-channel gate (x + x * gate).
        gated = [
            feat + feat * gates[:, idx].unsqueeze(1)
            for idx, feat in enumerate(stacked)
        ]

        fused = self.conv1x1_fuse(torch.cat(gated, 1))
        return self.act(self.bn(fused))


class DownsamplerBlock(nn.Module):
    """1x1 convolution, then a strided 5x5 depthwise convolution, batch norm and PReLU.

    Args:
        in_planes: number of channels of the incoming tensor.
        out_planes: number of output channels.
        stride: stride of the 5x5 depthwise convolution (default 2).
    """

    def __init__(self, in_planes, out_planes, stride=2):
        super().__init__()
        self.conv0 = nn.Conv2d(
            in_planes, out_planes, kernel_size=1, stride=1, padding=0,
            groups=1, bias=False,
        )
        # groups == out_planes makes this a depthwise convolution.
        self.conv1 = nn.Conv2d(
            out_planes, out_planes, kernel_size=5, stride=stride, padding=2,
            groups=out_planes, bias=False,
        )
        self.bn = nn.BatchNorm2d(out_planes)
        self.act = nn.PReLU(out_planes)

    def forward(self, input):
        """Apply both convolutions, then batch norm and PReLU, and return the result."""
        pointwise = self.conv0(input)
        spatial = self.conv1(pointwise)
        normalised = self.bn(spatial)
        return self.act(normalised)


def split(x):
    """Cut a 4-D tensor in two along dimension 1 and return both contiguous pieces.

    The first piece holds the first ``size(1) // 2`` entries of dimension 1 and
    the second piece holds the remainder, so for an odd size the second piece
    is one entry larger.
    """
    channels = int(x.size()[1])
    half = channels // 2
    head = x[:, :half, :, :]
    tail = x[:, half:, :, :]
    return head.contiguous(), tail.contiguous()


class MiniSeg(nn.Module):
    """MiniSeg segmentation network: four-level two-stream encoder plus a top-down decoder.

    Every encoder level halves the resolution (stride-2 ``downN`` / ``longN``)
    and then refines two streams side by side: the main stream (``levelN``)
    and the "long" stream (``levelN_long``). After levels 1-3 the two streams
    are concatenated, mixed by a 1x1 convolution (``catN``) and split again to
    feed the next level. The decoder upsamples level 4 step by step and fuses
    it with the main-stream output of levels 3, 2 and 1.

    How to make the model deeper: raise ``P1``..``P4``. Each value is the
    number of blocks stacked in the corresponding encoder level, so e.g.
    ``MiniSeg(P3=12, P4=10)`` adds blocks to levels 3 and 4 without touching
    any other code. The channel widths (8, 24, 32, 64) are fixed.

    Args:
        classes: number of output channels of every prediction head.
        P1: number of ``ConvBlock`` blocks in level 1.
        P2: number of ``DilatedParallelConvBlock`` blocks in level 2.
        P3: number of ``DilatedParallelConvBlock`` blocks in level 3.
        P4: number of ``DilatedParallelConvBlock`` blocks in level 4.
        aux: when true, also build and return the three auxiliary heads
            (decoder stages 2, 3 and 4).

    The long stream of level N gets ``int(PN / 2)`` blocks, stored as ``DN``.
    The first layers expect a 3-channel input.
    """

    def __init__(self, classes=2, P1=2, P2=3, P3=8, P4=6, aux=True):
        super(MiniSeg, self).__init__()
        # Number of long-stream blocks per level: half the main-stream depth.
        self.D1 = int(P1/2)
        self.D2 = int(P2/2)
        self.D3 = int(P3/2)
        self.D4 = int(P4/2)
        self.aux = aux

        # ---- Encoder level 1: 3 -> 8 channels ----
        self.long1 = DownsamplerBlock(3, 8, stride=2)
        self.down1 = ConvBlock(3, 8, stride=2)
        self.level1 = nn.ModuleList([ConvBlock(8, 8) for _ in range(P1)])
        self.level1_long = nn.ModuleList(
            [DownsamplerBlock(8, 8, stride=1) for _ in range(self.D1)])

        self.cat1 = nn.Sequential(
                    nn.Conv2d(16, 16, 1, stride=1, padding=0, groups=1, bias=False),
                    nn.BatchNorm2d(16))

        # ---- Encoder level 2: 8 -> 24 channels ----
        self.long2 = DownsamplerBlock(8, 24, stride=2)
        self.down2 = DilatedParallelConvBlock(8, 24, stride=2)
        self.level2 = nn.ModuleList(
            [DilatedParallelConvBlock(24, 24) for _ in range(P2)])
        self.level2_long = nn.ModuleList(
            [DownsamplerBlock(24, 24, stride=1) for _ in range(self.D2)])

        self.cat2 = nn.Sequential(
                    nn.Conv2d(48, 48, 1, stride=1, padding=0, groups=1, bias=False),
                    nn.BatchNorm2d(48))

        # ---- Encoder level 3: 24 -> 32 channels ----
        self.long3 = DownsamplerBlock(24, 32, stride=2)
        self.down3 = DilatedParallelConvBlock(24, 32, stride=2)
        self.level3 = nn.ModuleList(
            [DilatedParallelConvBlock(32, 32) for _ in range(P3)])
        self.level3_long = nn.ModuleList(
            [DownsamplerBlock(32, 32, stride=1) for _ in range(self.D3)])

        self.cat3 = nn.Sequential(
                    nn.Conv2d(64, 64, 1, stride=1, padding=0, groups=1, bias=False),
                    nn.BatchNorm2d(64))

        # ---- Encoder level 4: 32 -> 64 channels ----
        self.long4 = DownsamplerBlock(32, 64, stride=2)
        self.down4 = DilatedParallelConvBlock(32, 64, stride=2)
        self.level4 = nn.ModuleList(
            [DilatedParallelConvBlock(64, 64) for _ in range(P4)])
        self.level4_long = nn.ModuleList(
            [DownsamplerBlock(64, 64, stride=1) for _ in range(self.D4)])

        # ---- Decoder ----
        self.up4_conv4 = nn.Conv2d(64, 64, 1, stride=1, padding=0)
        self.up4_bn4 = nn.BatchNorm2d(64)
        self.up4_act = nn.PReLU(64)

        self.up3_conv4 = DilatedParallelConvBlockD2(64, 32)
        self.up3_conv3 = nn.Conv2d(32, 32, 1, stride=1, padding=0)
        self.up3_bn3 = nn.BatchNorm2d(32)
        self.up3_act = nn.PReLU(32)

        self.up2_conv3 = DilatedParallelConvBlockD2(32, 24)
        self.up2_conv2 = nn.Conv2d(24, 24, 1, stride=1, padding=0)
        self.up2_bn2 = nn.BatchNorm2d(24)
        self.up2_act = nn.PReLU(24)

        self.up1_conv2 = DilatedParallelConvBlockD2(24, 8)
        self.up1_conv1 = nn.Conv2d(8, 8, 1, stride=1, padding=0)
        self.up1_bn1 = nn.BatchNorm2d(8)
        self.up1_act = nn.PReLU(8)

        # ---- Prediction heads (auxiliary heads only exist when aux is set) ----
        if self.aux:
            self.pred4 = nn.Sequential(nn.Dropout2d(0.01, False), nn.Conv2d(64, classes, 1, stride=1, padding=0))
            self.pred3 = nn.Sequential(nn.Dropout2d(0.01, False), nn.Conv2d(32, classes, 1, stride=1, padding=0))
            self.pred2 = nn.Sequential(nn.Dropout2d(0.01, False), nn.Conv2d(24, classes, 1, stride=1, padding=0))
        self.pred1 = nn.Sequential(nn.Dropout2d(0.01, False), nn.Conv2d(8, classes, 1, stride=1, padding=0))

    @staticmethod
    def _run_level(blocks, long_blocks, num_long, output, long):
        """Refine one encoder level and return the updated ``(output, long)`` pair.

        Every main-stream block reads the sum of both streams and is added
        residually to ``output``. For the first ``num_long`` iterations the
        long-stream block with the same index reads the same sum and is added
        residually to ``long``; afterwards ``long`` is carried unchanged.
        """
        fused = output + long
        for i, block in enumerate(blocks):
            output = block(fused) + output
            if i < num_long:
                long = long_blocks[i](fused) + long
            fused = output + long
        return output, long

    def forward(self, input):
        """Return a tuple of predictions resized to the spatial size of ``input``.

        With ``aux`` set the tuple is ``(pred1, pred2, pred3, pred4)``;
        otherwise it is the one-element tuple ``(pred1,)``.
        """
        # ---- Encoder ----
        long1 = self.long1(input)
        output1 = self.down1(input)
        output1, long1 = self._run_level(
            self.level1, self.level1_long, self.D1, output1, long1)

        # Mix both streams, split the result in two and add each half back
        # onto its stream before going one level down.
        output1_cat = self.cat1(torch.cat([long1, output1], 1))
        output1_l, output1_r = split(output1_cat)

        long2 = self.long2(output1_l + long1)
        output2 = self.down2(output1_r + output1)
        output2, long2 = self._run_level(
            self.level2, self.level2_long, self.D2, output2, long2)

        output2_cat = self.cat2(torch.cat([long2, output2], 1))
        output2_l, output2_r = split(output2_cat)

        long3 = self.long3(output2_l + long2)
        output3 = self.down3(output2_r + output2)
        output3, long3 = self._run_level(
            self.level3, self.level3_long, self.D3, output3, long3)

        output3_cat = self.cat3(torch.cat([long3, output3], 1))
        output3_l, output3_r = split(output3_cat)

        long4 = self.long4(output3_l + long3)
        output4 = self.down4(output3_r + output3)
        output4, long4 = self._run_level(
            self.level4, self.level4_long, self.D4, output4, long4)

        # ---- Decoder: upsample the deeper stage, fuse with the skip, activate ----
        up4_conv4 = self.up4_bn4(self.up4_conv4(output4))
        up4 = self.up4_act(up4_conv4)

        up4 = F.interpolate(up4, output3.size()[2:], mode='bilinear', align_corners=False)
        up3_conv4 = self.up3_conv4(up4)
        up3_conv3 = self.up3_bn3(self.up3_conv3(output3))
        up3 = self.up3_act(up3_conv4 + up3_conv3)

        up3 = F.interpolate(up3, output2.size()[2:], mode='bilinear', align_corners=False)
        up2_conv3 = self.up2_conv3(up3)
        up2_conv2 = self.up2_bn2(self.up2_conv2(output2))
        up2 = self.up2_act(up2_conv3 + up2_conv2)

        up2 = F.interpolate(up2, output1.size()[2:], mode='bilinear', align_corners=False)
        up1_conv2 = self.up1_conv2(up2)
        up1_conv1 = self.up1_bn1(self.up1_conv1(output1))
        up1 = self.up1_act(up1_conv2 + up1_conv1)

        # ---- Heads ----
        # The heads see the already-upsampled up4 / up3 / up2. The auxiliary
        # heads run before pred1, so the dropout layers are called in a fixed
        # order.
        target_size = input.size()[2:]
        aux_preds = ()
        if self.aux:
            pred4 = F.interpolate(self.pred4(up4), target_size, mode='bilinear', align_corners=False)
            pred3 = F.interpolate(self.pred3(up3), target_size, mode='bilinear', align_corners=False)
            pred2 = F.interpolate(self.pred2(up2), target_size, mode='bilinear', align_corners=False)
            aux_preds = (pred2, pred3, pred4)
        pred1 = F.interpolate(self.pred1(up1), target_size, mode='bilinear', align_corners=False)

        return (pred1, ) + aux_preds



@inproceedings{qiu2021miniseg,
  title={Mini{S}eg: An Extremely Minimum Network for Efficient {COVID}-19 Segmentation},
  author={Qiu, Yu and Liu, Yun and Li, Shijie and Xu, Jing},
  booktitle={AAAI Conference on Artificial Intelligence},
  year={2021}
}