EfficientNet on MRI reconstruction data

Hi experts, I am working with EfficientNet and want to modify it to reconstruct high-detail (high-PSNR) MRI images from low-detail MRI images. I am in the initial stages, so for now I have only replaced the fully connected (FC) classification layer with U-Net-style upsampling layers. Is this approach roughly correct?
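Since PSNR is the target metric, here is a minimal sketch of the usual definition, 10·log10(MAX²/MSE), for comparing a reconstruction against a high-detail reference. The function name `psnr` is hypothetical, and taking MAX as the reference image's dynamic range is an assumption. MRI work varies here (fixed 1.0 for normalized data, or the reference maximum), so match whatever convention your evaluation uses.

```python
from typing import Optional

import numpy as np


def psnr(reference: np.ndarray, reconstruction: np.ndarray,
         data_range: Optional[float] = None) -> float:
    """Peak signal-to-noise ratio in dB: 10 * log10(MAX^2 / MSE)."""
    reference = reference.astype(np.float64)
    reconstruction = reconstruction.astype(np.float64)
    if data_range is None:
        # Assumption: MAX is the dynamic range of the reference image.
        data_range = reference.max() - reference.min()
    mse = np.mean((reference - reconstruction) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(data_range ** 2 / mse))


# Usage with synthetic stand-in data (not real MRI slices).
target = np.random.rand(320, 320)
noisy = target + 0.05 * np.random.randn(320, 320)
print(f"PSNR: {psnr(target, noisy):.2f} dB")
```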

The original EfficientNet classification head is as follows:

    # Head
    in_channels = block_args.output_filters  # output of final block
    out_channels = round_filters(1280, self._global_params)
    Conv2d = get_same_padding_conv2d(image_size=image_size)
    self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
    self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)

    # Final linear layer
    self._avg_pooling = nn.AdaptiveAvgPool2d(1)
    self._dropout = nn.Dropout(self._global_params.dropout_rate)
    self._fc = nn.Linear(out_channels, self._global_params.num_classes)
    self._swish = MemoryEfficientSwish()
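To see what spatial information reaches this head, it helps to track the feature-map size. The sketch below assumes TF-style "same" padding (which `get_same_padding_conv2d` provides), where a strided conv outputs ceil(size / stride), and the EfficientNet-B0 stride pattern: stem stride 2, then stage strides 1, 2, 2, 2, 1, 2, 1 (the scaled variants reuse these strides). The helper names are hypothetical. `_conv_head` is a 1x1 conv, so it keeps the spatial size; `nn.AdaptiveAvgPool2d(1)` then reduces it to 1x1 before `_fc`, which is what a classifier wants.

```python
import math

# Stem conv (stride 2) followed by the seven MBConv stages of EfficientNet-B0.
B0_STRIDES = [2, 1, 2, 2, 2, 1, 2, 1]


def same_padding_out(size: int, stride: int) -> int:
    # With TF-style "same" padding, a conv with this stride outputs ceil(size / stride).
    return math.ceil(size / stride)


def head_feature_size(size: int, strides=B0_STRIDES) -> int:
    # Apply each stage's stride in turn to get the size seen by _conv_head.
    for s in strides:
        size = same_padding_out(size, s)
    return size


h = head_feature_size(320)
print(h)  # 10 -> features are [N, C, 10, 10] before _avg_pooling
# After nn.AdaptiveAvgPool2d(1) they become [N, C, 1, 1]: the spatial layout is gone.
```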

My modifications (line 234):

    in_channels = block_args.output_filters  # output of final block
    out_channels = round_filters(320, self._global_params)  # 320 for B0
    Conv2d = get_same_padding_conv2d(image_size=image_size)

    self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
    self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)

    # Final layer
    self._avg_pooling = nn.AdaptiveAvgPool2d(1)  # outputs a 1x1 spatial map
    #self._dropout = nn.Dropout(self._global_params.dropout_rate)
    #self._fc = nn.Conv2d(out_channels, out_channels, kernel_size=1)

    # Upsampling path (channel counts shown for B0)
    self.upconv_1 = UpConv2d(out_channels, out_channels // 2, 4, 1)  # 320 -> 160
    out_channels //= 2
    self.upconv_2 = UpConv2d(out_channels, out_channels // 2, 4, 1)  # 160 -> 80
    out_channels //= 2
    self.upconv_3 = UpConv2d(out_channels, out_channels // 2, 4, 1)  # 80 -> 40
    out_channels //= 2
    self.upconv_4 = UpConv2d(out_channels, out_channels // 2, 4, 1)  # 40 -> 20
    out_channels //= 2
    #self.upconv_5 = UpConv2d(out_channels, 1, 2, 1)  # 20 -> 10
    self.final = finalconv(out_channels, 1, 320, 1)  # crop output size to 320
    self._swish = MemoryEfficientSwish()
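`UpConv2d` and `finalconv` are the author's own helpers and are not shown, so their behaviour is unknown. Assuming, purely for illustration, that `UpConv2d(c_in, c_out, 4, 1)` wraps `nn.ConvTranspose2d` with kernel size 4 and stride 1, the output-size formula from the PyTorch `ConvTranspose2d` documentation shows how much each stage actually grows the image. The helper name `conv_transpose_out` is hypothetical.

```python
def conv_transpose_out(size: int, kernel_size: int, stride: int = 1, padding: int = 0,
                       output_padding: int = 0, dilation: int = 1) -> int:
    # Output-size formula documented for torch.nn.ConvTranspose2d.
    return ((size - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)


# Assumed reading of UpConv2d(c_in, c_out, 4, 1): kernel 4, stride 1, no padding.
# Starting from 1x1 (if the forward pass still applies _avg_pooling), each stage adds 3 pixels.
size = 1
for _ in range(4):
    size = conv_transpose_out(size, kernel_size=4, stride=1)
print(size)  # 13

# Kernel 4, stride 2, padding 1 instead doubles the size exactly at each stage.
size = 10  # B0 feature size for a 320x320 input
for _ in range(5):
    size = conv_transpose_out(size, kernel_size=4, stride=2, padding=1)
print(size)  # 320
```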

Here I upsample so that the output image size matches the input image size.
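A minimal PyTorch sketch of an upsampling decoder that keeps the spatial feature map from `_conv_head`/`_bn1` instead of pooling it to 1x1. Everything here is an assumption for illustration: the class name `UpsamplingDecoder`, a 320-channel 10x10 input (B0 with a 320x320 image), five stride-2 `ConvTranspose2d` stages, and a single output channel. A full U-Net would also concatenate skip connections from earlier encoder stages, which this sketch omits.

```python
import torch
import torch.nn as nn


class UpsamplingDecoder(nn.Module):
    """Hypothetical decoder: [N, 320, 10, 10] -> [N, 1, 320, 320]."""

    def __init__(self, in_channels: int = 320, out_channels: int = 1, num_stages: int = 5):
        super().__init__()
        layers = []
        channels = in_channels
        for _ in range(num_stages):
            # kernel 4, stride 2, padding 1 exactly doubles height and width
            layers += [
                nn.ConvTranspose2d(channels, channels // 2, kernel_size=4,
                                   stride=2, padding=1, bias=False),
                nn.BatchNorm2d(channels // 2),
                nn.SiLU(),  # SiLU computes x * sigmoid(x), the same function as Swish
            ]
            channels //= 2  # 320 -> 160 -> 80 -> 40 -> 20 -> 10
        self.up = nn.Sequential(*layers)
        # 1x1 conv maps the remaining feature channels to one image channel
        self.final = nn.Conv2d(channels, out_channels, kernel_size=1)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        return self.final(self.up(features))


decoder = UpsamplingDecoder()
x = torch.randn(2, 320, 10, 10)  # stands in for _bn1 output, with no _avg_pooling
y = decoder(x)
print(y.shape)  # torch.Size([2, 1, 320, 320])
```

With exact 2x stages, no final cropping is needed as long as the input size is a multiple of 32; `nn.Upsample` followed by a regular convolution is a common alternative to transposed convolutions.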

The repo linked above contains the code for my modifications.

efficientnetdebug.ipynb: this notebook uses torchsummary to print a summary of the model.
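Alongside torchsummary, a quick sanity check that needs only PyTorch is to count parameters and push one dummy batch through the model to read off the output shape. The helper names `count_parameters` and `output_shape` are hypothetical, and the small conv + batch-norm `head` is only a stand-in for the modified EfficientNet.

```python
import torch
import torch.nn as nn


def count_parameters(model: nn.Module, trainable_only: bool = True) -> int:
    # Sum of element counts over parameter tensors (BatchNorm running stats are buffers, not counted).
    return sum(p.numel() for p in model.parameters() if p.requires_grad or not trainable_only)


@torch.no_grad()
def output_shape(model: nn.Module, input_shape) -> torch.Size:
    # Run one zero-filled batch in eval mode, then restore the previous mode.
    was_training = model.training
    model.eval()
    out = model(torch.zeros(*input_shape))
    model.train(was_training)
    return out.shape


head = nn.Sequential(nn.Conv2d(320, 160, kernel_size=1), nn.BatchNorm2d(160))
print(count_parameters(head))                 # 51680 (320*160 + 160 conv, 2*160 BN)
print(output_shape(head, (1, 320, 10, 10)))   # torch.Size([1, 160, 10, 10])
```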