I didn't manage to freeze some layers

Hi,

I’m trying to train a new model that performs segmentation together with another task (separating a superposition of two images).

The segmentation part of the model is the pre-trained deeplabv3-resnet101 model.
In the long run I should also train this part and add some skip connections, but as a first step I took the model as is and froze its layers with requires_grad=False.

Even so, it seems that these layers have changed: the segmentation output of my model is completely different both from the built-in model and from my model when the trained weights are not loaded.

I would appreciate your help!

Did you call model.eval() when you compared the outputs?
If so, could you post a minimal code snippet which shows that the model parameters get updated while they are frozen?
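
For reference, a minimal sketch of how the outputs could be compared in eval() mode (assuming the deeplab_v3_separation_dev class posted below); both models are put into eval() so dropout and batch norm behave deterministically:

import torch
import torchvision.models as models

net = deeplab_v3_separation_dev()
reference = models.segmentation.deeplabv3_resnet101(pretrained=True)
net.eval()
reference.eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    seg_pred, _, _ = net(x)                        # custom model applies sigmoid internally
    ref_pred = torch.sigmoid(reference(x)['out'])  # stock model returns raw logits

# The custom model applies the sigmoid before the bilinear upsampling, so the
# values will not match the stock output exactly; comparing the per-pixel
# argmax class maps is a more robust check.
same = (seg_pred.argmax(1) == ref_pred.argmax(1)).float().mean()
print('fraction of pixels with the same predicted class:', same.item())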

Yep.

The code snippet:

The model:


import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

class deeplab_v3_separation_dev(nn.Module):
    def __init__(self):
        super(deeplab_v3_separation_dev, self).__init__()
        # pre-trained DeepLabV3 with a ResNet-101 backbone
        deepLabV3ResNet101 = models.segmentation.deeplabv3_resnet101(pretrained=True, aux_loss=None)
        ClassifierLayers = deepLabV3ResNet101.classifier
        ASSP_features = ClassifierLayers[0]                        # ASPP module
        ResNet101_features = list(deepLabV3ResNet101.backbone.children())
        semantic_layers = list(ClassifierLayers)

        self.ResNet101_features = nn.Sequential(*ResNet101_features)
        self.ASSP_features = ASSP_features

        self.Semantic_head = nn.Sequential(*semantic_layers[1:])  # classifier layers after the ASPP module
        self.SFTMX = nn.Sigmoid()                                  # (a sigmoid, despite the name)
        # self.Background_convs = nn.Sequential(nn.Conv2d(280, 3, kernel_size=(1, 1), stride=(1, 1), bias=False))#,
                                              # nn.BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                              # nn.ReLU(),
                                              # nn.Dropout(p=0.5, inplace=False))

        # self.Separation_convs = nn.Sequential(nn.Conv2d(283, 3, kernel_size=(1, 1), stride=(1, 1), bias=False))#,
                                              # nn.BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                              # nn.ReLU(),
                                              # nn.Dropout(p=0.5, inplace=False))
        self.Background_convs = nn.Sequential(nn.Conv2d(24, 24, 1),
                                              nn.BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                              nn.ReLU(),
                                              nn.Conv2d(24, 3, 1))
        self.Separation_convs = nn.Sequential(nn.Conv2d(27, 27, 1),
                                              nn.BatchNorm2d(27, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                              nn.ReLU(),
                                              nn.Conv2d(27, 3, 1))
        
    def forward(self, x):
        input_shape = x.shape[-2:]
        x1 = self.ResNet101_features(x)                          # backbone features
        x2 = self.ASSP_features(x1)                              # ASPP features
        y1 = self.Semantic_head(x2)                              # segmentation scores
        y1 = self.SFTMX(y1)
        y1 = F.interpolate(y1, size=input_shape, mode='bilinear', align_corners=False)
        x2 = F.interpolate(x2, size=input_shape, mode='bilinear', align_corners=False)
        y2 = self.Background_convs(torch.cat((y1, x), 1))        # x2,
        y3 = self.Separation_convs(torch.cat((y1, y2, x), 1))    # x2,
        return y1, y2, y3
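
As a side note, a quick sketch (using the attribute names above) that prints how many parameters each top-level submodule contributes, which makes it easier to see which layers an index-based cutoff such as i < 340 actually covers:

net = deeplab_v3_separation_dev()
offset = 0
for name, module in net.named_children():
    n = sum(1 for _ in module.parameters())
    if n > 0:
        # parameters() iterates children in registration order, so these
        # offsets line up with the flat index used when freezing by count
        print(f'{name}: parameters {offset} .. {offset + n - 1}')
    offset += n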
    
 

The layer-freezing code:

    for i, param in enumerate(net.parameters()):
        if i < 340:
            param.requires_grad = False
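
As an alternative sketch (using the submodule names from the model above), the pre-trained parts can be frozen by name rather than by counting parameter indices. Note also that requires_grad=False only stops gradient updates to the parameters; BatchNorm layers inside the frozen parts still update their running statistics while the model is in train() mode, which changes the outputs, so the frozen submodules can additionally be put into eval() mode:

    for module in (net.ResNet101_features, net.ASSP_features, net.Semantic_head):
        for param in module.parameters():
            param.requires_grad = False
        # keep BatchNorm running statistics (and dropout) fixed in these parts
        module.eval()

Keep in mind that net.train() in the training loop switches all submodules back to training mode, so the eval() calls would have to be repeated after it (or train() overridden) for this to take effect.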

and the training script:

def train_net(net,
              device,
              epochs=5,
              batch_size=2,
              lr=0.05,
              val_percent=0,
              save_cp=True,
              img_scale=0.25):

    dataset = BasicDataset(dir_img, dir_mask, img_scale)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True, drop_last=True)

    deepLabV3ResNet101 = models.segmentation.deeplabv3_resnet101(pretrained=True)
    writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
    global_step = 0
    optimizer = optim.Adam(net.parameters(), lr=lr, betas=(0.7, 0.999), eps=1e-08, weight_decay=1e-3, amsgrad=False)
    scheduler = optim.lr_scheduler.StepLR(optimizer,  step_size=5, gamma=(0.3))
    L1_criterion = nn.L1Loss(reduction='none')    
    semantic_criterion = nn.BCELoss(reduction='none')#BCEWithLogitsLoss
    for epoch in range(epochs):
        net.train()
        with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                mixedImgs = batch['mixedImage']
                if mixedImgs.size()[0]==1:
                    continue                   
                BackgroundImage = batch['BackgroundImage']
                ReflectionImage = batch['ReflectionImage']
                BackgtoundSemantics = batch['BackgtoundSemantics']
                mixedImgs = mixedImgs.to(device=device) #, dtype=torch.float32)
                BackgroundImage = BackgroundImage.to(device=device) #, dtype=torch.float32)*255
                ReflectionImage = ReflectionImage.to(device=device) #, dtype=torch.float32)
                BackgtoundSemantics =  BackgtoundSemantics.to(device=device , dtype=torch.float32) # , dtype=torch.long            
                semantic_pred, BG_img_pred, R_img_pred = net(mixedImgs) # , BG_img_pred, R_img_pred
                loss1 = semantic_criterion(semantic_pred,BackgtoundSemantics ) #BackgtoundSemantics[:,:,:,:]   
                var1 = torch.var(loss1)            
                loss2 =  L1_criterion(BG_img_pred, BackgroundImage)
                # # # # loss += 0.0003*LA.norm(cv2.Canny(BG_img_pred,100,200) - cv2.Canny(BackgroundImage,100,200))

                loss3 = 0.6*(1 - ssim.ssim(BG_img_pred, BackgroundImage))
                loss4 = 0.8*L1_criterion(R_img_pred,ReflectionImage)# torch.sum(torch.abs(R_img_pred-ReflectionImage),1)
                var2 = torch.var(loss4) #dim=0, keepdim=True, unbiased=False

                Loss1 = loss1/(2*var1.square())
                Loss2 = (loss2+loss3+loss4) /(2*var2.square()) 

                # # epoch_loss += loss.item()
                loss = 0*torch.mean(Loss2)+0*torch.mean(Loss1)
            
                writer.add_scalar('Loss/train', loss.item(), global_step)
                pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()

                nn.utils.clip_grad_value_(net.parameters(), 0.1)
                optimizer.step()

                pbar.update(mixedImgs.shape[0])
                global_step += 1
                if global_step % (len(dataset) // (10 * batch_size)) == 0:
                    scheduler.step()
                    writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)
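
A small optional refinement, sketched under the assumption that the layers are frozen before the optimizer is created: the Adam construction above could pass only the trainable parameters, which makes the intent explicit and keeps the optimizer from tracking parameters that never receive gradients.

    trainable_params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable_params, lr=lr, betas=(0.7, 0.999),
                           eps=1e-08, weight_decay=1e-3, amsgrad=False)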

Your code snippet is unfortunately not executable, but using random inputs shows that the frozen parameters are not updated:

# freeze some layers and store reference params
net = deeplab_v3_separation_dev()
ref_params_freeze = []
ref_params_update = []
i=0
for param in net.parameters():
    if i<340 :
        param.requires_grad = False
        ref_params_freeze.append(param.clone())
    else:
        ref_params_update.append(param.clone())
    i+=1 
    

# setup
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3, betas=(0.7, 0.999), eps=1e-08, weight_decay=1e-3, amsgrad=False)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,  step_size=5, gamma=(0.3))
L1_criterion = nn.L1Loss(reduction='none')    
semantic_criterion = nn.BCELoss(reduction='none')#BCEWithLogitsLoss

net.train()

mixedImgs = torch.randn(2, 3 ,224, 224)
BackgtoundSemantics = torch.randint(0, 2, (2, 21, 224, 224)).float()
BackgroundImage = torch.randint(0, 2, (2, 3, 224, 224)).float()
ReflectionImage = torch.randint(0, 2, (2, 3, 224, 224)).float()

# updates
semantic_pred, BG_img_pred, R_img_pred = net(mixedImgs)
loss1 = semantic_criterion(semantic_pred,BackgtoundSemantics)
var1 = torch.var(loss1)
loss2 =  L1_criterion(BG_img_pred, BackgroundImage)
loss3 = 0. # the ssim module is not defined here, so this term is skipped

loss4 = 0.8*L1_criterion(R_img_pred,ReflectionImage)
var2 = torch.var(loss4)
Loss1 = loss1/(2*var1.square())
Loss2 = (loss2+loss3+loss4) /(2*var2.square()) 
loss = 0*torch.mean(Loss2)+0*torch.mean(Loss1)

# step
optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_value_(net.parameters(), 0.1)
optimizer.step()

# check param updates
i = 0
for param in net.parameters():
    if i < 340:
        print('frozen param abs().max() err {}'.format(
            (ref_params_freeze[i] - param).abs().max()))
    else:
        print('trainable param abs().max() err {}'.format(
            (ref_params_update[i-340] - param).abs().max()))
    i += 1
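
As a follow-up sketch: the check above only covers parameters. Buffers such as the BatchNorm running statistics are not parameters, so requires_grad=False does not freeze them, and they already change after a single forward pass in train() mode. This may explain why the segmentation output drifts even though the frozen weights stay fixed (names below assume the same model class):

net = deeplab_v3_separation_dev()
for i, param in enumerate(net.parameters()):
    if i < 340:
        param.requires_grad = False

# snapshot all buffers (BatchNorm running_mean/running_var, num_batches_tracked)
ref_buffers = {name: buf.clone() for name, buf in net.named_buffers()}

net.train()
with torch.no_grad():
    net(torch.randn(2, 3, 224, 224))  # a single forward pass, no backward/step at all

changed = [name for name, buf in net.named_buffers()
           if not torch.equal(ref_buffers[name], buf)]
print(f'{len(changed)} buffers changed, e.g. {changed[:3]}')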

I’m sorry, I tried to zero the loss to check something, and forgot to change it back.

How can I freeze the layers?


You are already freezing the first 340 parameters, so I’m unsure what exactly is not working at the moment. Could you check my code snippet and see what is not working for you?

Your code works well. I still don’t understand what the problem is.

I just tried zeroing only the segmentation loss, and the parameters didn’t update.
Now it seems to me that it was the segmentation loss that updated the frozen parameters.
I’m checking again without zeroing that loss, to make sure that this is really the problem.