UNet model class with freezing parameter

Hi all, I am working with microscopy images. I have implemented a UNet model for image segmentation and trained it on 1800 images/labels. Now I want to use transfer learning to segment new sample images. I have hard-coded a function that first freezes the whole model with this code:

    for param in model.module.parameters(): 
        param.requires_grad=False

and then, based on the block number, I unfreeze the layers that I want to train.
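
For example, to retrain only one encoder block I currently do something like this (using down_conv_3 from the model below just as an example, and a placeholder learning rate):

    # unfreeze only the parameters of one block; everything else stays frozen
    for param in model.module.down_conv_3.parameters():
        param.requires_grad = True

    # give the optimizer only the parameters that are still trainable
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=1e-4)  # placeholder value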

Is there an easier way to unfreeze the layers (including the ones inside the Sequential blocks)? I was thinking of passing a parameter to my model class that could freeze specific layers based on a list or tuple; I put a rough sketch of that idea after the model code below.
Thanks in advance.

import torch
import torch.nn as nn


def double_conv(in_c, out_c):
    conv = nn.Sequential(
        nn.Conv2d(in_c, out_c,
                  kernel_size=3, stride=1,
                  padding=1, bias=False),
        nn.ReLU(inplace=True),
        nn.BatchNorm2d(out_c),
        nn.Conv2d(out_c, out_c,
                  kernel_size=3, stride=1,
                  padding=1, bias=False),
        nn.ReLU(inplace=True),
        nn.BatchNorm2d(out_c),
        nn.Dropout(0.1, inplace=True)  # I put it at the end of the block; in Keras I had it in between the two Conv2d
    )
    return conv

class UNet_(nn.Module):
    def __init__(self):
        
        super(UNet_,self).__init__()
        self.down_conv_1=double_conv(1,16) 
        self.max1=nn.MaxPool2d(kernel_size=2)
        self.down_conv_2=double_conv(16,32)
        self.max2=nn.MaxPool2d(kernel_size=2)
        self.down_conv_3=double_conv(32,64)
        self.max3=nn.MaxPool2d(kernel_size=2)
        self.down_conv_4=double_conv(64,128)
        self.max4=nn.MaxPool2d(kernel_size=2)
        self.down_conv_5=double_conv(128,256)
        self.max5=nn.MaxPool2d(kernel_size=2)    # not used in forward() below
        self.down_conv_6=double_conv(256,512)    # not used in forward() below
        
        # decoder 
#------------------------------------------------
        self.up_trans_1=nn.ConvTranspose2d( #1
            in_channels=256,
            out_channels=128,
            kernel_size=2,
            stride=2)

        self.up_conv_1=double_conv(256,128)
               
        #---------------------------------------------------------------------       
        self.up_trans_2=nn.ConvTranspose2d( #2
            in_channels=128,
            out_channels=64,
            kernel_size=2,
            stride=2)

        self.up_conv_2=double_conv(128,64)
#--------------------------------------------------------------------        
        self.up_trans_3=nn.ConvTranspose2d( #3
            in_channels=64,
            out_channels=32,
            kernel_size=2,
            stride=2)
        self.up_conv_3=double_conv(64,32)
#------------------------------------------------------------------        
        self.up_trans_4=nn.ConvTranspose2d( #4
            in_channels=32,
            out_channels=16,
            kernel_size=2,
            stride=2)
        self.up_conv_4=double_conv(32,16)
        
        
        self.up_trans_5=nn.ConvTranspose2d(      # not used in forward() below
            in_channels=16,
            out_channels=1,
            kernel_size=2,
            stride=2)

        self.up_conv_5=double_conv(16,1)         # not used in forward() below
        
#------------ out channel       
    
        self.out=nn.Sequential(
            nn.Conv2d(in_channels=16,
                      out_channels=1,
                      kernel_size=1),
            nn.Sigmoid())
        
    
    def forward(self,image):  # (idea: pass a freezing=False argument here)
        
        
        #encoder part (what goes up must come down!!!)
        
        x1=self.down_conv_1(image)#  Input channels / Output channels (1,16) 
        m1=self.max1(x1)

        x2=self.down_conv_2(m1)# Input channels / Output channels (16,32)
        m2=self.max2(x2)
        
        x3=self.down_conv_3(m2)#Input channels / Output channels (32,64)
        m3=self.max3(x3)
        
        x4=self.down_conv_4(m3)# Input channels / Output channels (64,128)
        m4=self.max4(x4)
        
        x5=self.down_conv_5(m4)# Input channels / Output channels (128,256)
                
        #decoder part (to the stars!!! )
    
        up1=self.up_trans_1(x5) #Input channels / Output channels (256,128)
        
        sc_1=torch.cat([up1,x4],1)  # skip connection with x4
        up1=self.up_conv_1(sc_1)
                
        up2=self.up_trans_2(up1)
        sc_2=torch.cat([up2,x3],1)
        up2=self.up_conv_2(sc_2)

        up3=self.up_trans_3(up2)
        sc_3=torch.cat([up3,x2],1)
        up3=self.up_conv_3(sc_3)
        
        up4=self.up_trans_4(up3)
        sc_4=torch.cat([up4,x1],1)
        up4=self.up_conv_4(sc_4) 
         #output layer___________________________________________________
        out=self.out(up4)

        return out
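
Something like this is what I had in mind, just as a rough sketch (the freeze argument and the freeze_blocks helper are names I made up, and the block names are the attribute names from the class above):

    class UNet_(nn.Module):
        def __init__(self, freeze=None):
            # freeze: list/tuple of block names to freeze, e.g. ("down_conv_1", "down_conv_2")
            super(UNet_, self).__init__()
            # ... same layers and forward() as in the class above ...
            if freeze is not None:
                self.freeze_blocks(freeze)

        def freeze_blocks(self, block_names):
            # freeze every parameter of the listed top-level blocks;
            # module.parameters() is recursive, so this also covers the
            # layers inside the nn.Sequential returned by double_conv
            for name, module in self.named_children():
                if name in block_names:
                    for param in module.parameters():
                        param.requires_grad = False

and then use it like:

    model = UNet_(freeze=("down_conv_1", "down_conv_2", "down_conv_3"))

    # quick check of what ended up frozen
    for name, param in model.named_parameters():
        print(name, param.requires_grad)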