ResNet18-based autoencoder

I want to build a ResNet18-based autoencoder for a binary classification problem. I have taken a U-Net decoder block from the segmentation_models_pytorch library.

Currently I am trying to do the following:

- I want to take the output of ResNet18 before the last average-pooling layer and send it to the decoder. I will use the decoder output to compute an L1 loss against the input image (see the encoder sketch after this list).
- I want to remove only the last linear layer and replace it with a linear layer for binary classification, since my problem requires binary classification. I then want to use this output to compute a BCE loss.
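
Roughly, this is what I have in mind for the encoder part (just a sketch on my side, assuming torchvision's resnet18; the ResNet18Encoder name and the choice of returning both tensors are only for illustration):

import torch
import torch.nn as nn
from torchvision import models


class ResNet18Encoder(nn.Module):
    """Returns both the feature map before avgpool (for the decoder)
    and the binary-classification logits (for the BCE loss)."""

    def __init__(self, num_classes=1):
        super().__init__()
        backbone = models.resnet18(pretrained=True)
        # everything up to, but not including, avgpool and fc
        self.features = nn.Sequential(*list(backbone.children())[:-2])
        self.avgpool = backbone.avgpool
        self.fc = nn.Linear(512, num_classes)  # replaces the original 1000-way fc

    def forward(self, x):
        fmap = self.features(x)                        # (B, 512, H/32, W/32) -> decoder input
        pooled = torch.flatten(self.avgpool(fmap), 1)  # (B, 512)
        logits = self.fc(pooled)                       # (B, 1) -> BCEWithLogitsLoss
        return fmap, logits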

Finally, these two losses will be added together in my final model. What is the best way to write this model in code?
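
For adding the two losses, I am thinking of something along these lines (just a sketch, with dummy tensors standing in for the real batch and model outputs):

import torch
import torch.nn as nn

l1_loss = nn.L1Loss()
bce_loss = nn.BCEWithLogitsLoss()  # takes raw logits, so no sigmoid inside the model

# dummy tensors standing in for a real batch and the model outputs
images = torch.randn(4, 3, 224, 224)
recon = torch.randn(4, 3, 224, 224)         # decoder output, same shape as the input
logits = torch.randn(4, 1)                  # classification head output
labels = torch.randint(0, 2, (4,)).float()  # binary targets

loss = l1_loss(recon, images) + bce_loss(logits, labels.unsqueeze(1))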

Here are my ResNet block and decoder block. But how can I use the model the way I want in the training loop? (See the training-loop sketch after the code below.)

import torch.nn as nn
from torchvision import models


# params is a config dict defined elsewhere in my code
class pretrainedModelBlock(nn.Module):
  def __init__(self, model_name = params['model'], classes = params["num_classes"], addDecoder = False):
    super(pretrainedModelBlock, self).__init__()
    self.model_name = model_name
    self.classes = classes
    self.addDecoder = addDecoder
    self.custom_model = self.create_model()

  def create_model(self):
    pretrained_model = getattr(models, self.model_name)(pretrained = True)

    if not self.addDecoder:
      # replace the 1000-way ImageNet head with a binary-classification head
      pretrained_model.fc = nn.Linear(in_features = 512, out_features = self.classes, bias = True)
    # else:
    #   pretrained_model.avgpool = nn.Identity()
    #   pretrained_model.fc = nn.Identity()

    return pretrained_model

  def forward(self, x):
    resnet_features = self.custom_model(x)
    return resnet_features

import torch
import torch.nn as nn
import torch.nn.functional as F

from segmentation_models_pytorch.base import modules as md


class DecoderBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels,
        use_batchnorm=True,
        attention_type=None,
    ):
        super().__init__()
        self.conv1 = md.Conv2dReLU(
            in_channels + skip_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.attention1 = md.Attention(attention_type, in_channels=in_channels + skip_channels)
        self.conv2 = md.Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.attention2 = md.Attention(attention_type, in_channels=out_channels)

    def forward(self, x, skip=None):
        # upsample by 2, optionally concatenate the encoder skip connection, then refine
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
            x = self.attention1(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.attention2(x)
        return x
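
This is roughly how I imagine wiring everything together and using it in the training loop (again just a sketch: CombinedModel is a made-up name, it reuses the ResNet18Encoder sketch from above, the decoder channel sizes are guesses for 224x224 inputs, and train_loader is my usual DataLoader):

class CombinedModel(nn.Module):
    """Shared ResNet18 encoder + U-Net-style decoder for reconstruction
    + linear head for binary classification."""

    def __init__(self):
        super().__init__()
        self.encoder = ResNet18Encoder(num_classes=1)
        # stack of DecoderBlocks from the 512-channel feature map back towards
        # image resolution; no skip connections in this simple version
        channels = [512, 256, 128, 64, 32, 16]
        self.decoder = nn.ModuleList(
            [DecoderBlock(channels[i], 0, channels[i + 1]) for i in range(len(channels) - 1)]
        )
        self.out_conv = nn.Conv2d(16, 3, kernel_size=1)  # back to 3 image channels

    def forward(self, x):
        fmap, logits = self.encoder(x)
        d = fmap
        for block in self.decoder:
            d = block(d)  # each block upsamples by 2: 7x7 -> 224x224 for 224x224 inputs
        recon = self.out_conv(d)
        return recon, logits


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CombinedModel().to(device)
l1_loss = nn.L1Loss()
bce_loss = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

model.train()
for images, labels in train_loader:
    images, labels = images.to(device), labels.to(device)

    recon, logits = model(images)
    loss = l1_loss(recon, images) + bce_loss(logits, labels.float().unsqueeze(1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()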